|
| 1 | +# Copyright 2021-2024 VMware, Inc. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +import csv |
| 4 | +import logging |
| 5 | +import pathlib |
| 6 | +import re |
| 7 | + |
| 8 | +import nltk |
| 9 | +from config import DOCUMENTS_CSV_FILE_LOCATION |
| 10 | +from config import EMBEDDINGS_PKL_FILE_LOCATION |
| 11 | +from nltk.corpus import stopwords |
| 12 | +from nltk.stem import WordNetLemmatizer |
| 13 | +from sentence_transformers import SentenceTransformer |
| 14 | +from vdk.api.job_input import IJobInput |
| 15 | + |
| 16 | +log = logging.getLogger(__name__) |
| 17 | + |
| 18 | + |
def clean_text(text):
    """
    Normalize raw text for downstream NLP (embedding and RAG).

    Lowercases the input, strips punctuation/special characters, drops
    English stopwords, and lemmatizes the surviving words so different
    inflections collapse to a single base form.

    :param text: A string containing the text to be processed.
    :return: The processed text as a string.
    """
    lowered = text.lower()
    # keep only word characters and whitespace
    lowered = re.sub(r"[^\w\s]", "", lowered)
    english_stopwords = set(stopwords.words("english"))
    lemmatizer = WordNetLemmatizer()
    kept = []
    for token in lowered.split():
        if token in english_stopwords:
            continue
        kept.append(lemmatizer.lemmatize(token))
    return " ".join(kept)
| 40 | + |
| 41 | + |
def load_and_clean_documents(filename):
    """
    Read a CSV of documents and clean each one for embedding.

    Assumes the first row is a header and that the document text lives in
    the first column; any additional columns are ignored.

    :param filename: path to the input CSV file
    :return: list of single-element lists, each holding one cleaned document
    """
    cleaned_documents = []
    # newline="" is required by the csv module so that embedded newlines
    # inside quoted fields are parsed correctly
    with open(filename, newline="", encoding="utf-8") as file:
        reader = csv.reader(file)
        next(reader, None)  # skip the header row
        for row in reader:
            if row:
                cleaned_documents.append([clean_text(row[0])])
    return cleaned_documents
| 52 | + |
| 53 | + |
def save_cleaned_documents(cleaned_documents, output_file):
    """
    Write the cleaned documents to a CSV file, one document per row.

    :param cleaned_documents: iterable of rows (each a list of strings)
    :param output_file: destination CSV path
    """
    # newline="" stops the csv module from emitting extra blank lines on
    # platforms with \r\n line endings
    with open(output_file, mode="w", newline="", encoding="utf-8") as out:
        csv.writer(out).writerows(cleaned_documents)
| 58 | + |
| 59 | + |
def embed_documents_in_batches(documents):
    """
    Embed documents one at a time with a sentence-transformers model.

    Each element of *documents* is passed to ``model.encode`` on its own
    because the available resources cannot handle encoding two documents
    per batch (so effectively batch size = 1 document).

    :param documents: sequence of documents (each a list with one string)
    :return: list of embedding vectors, one per input string
    """
    # the model card: https://huggingface.co/sentence-transformers/all-mpnet-base-v2
    model = SentenceTransformer("all-mpnet-base-v2")
    total = len(documents)
    log.info(f"total: {total}")
    embeddings = []
    # iterate directly instead of indexing with range(len(...))
    for batch in documents:
        log.info(f"BATCH: {len(batch)}.")
        embeddings.extend(model.encode(batch, show_progress_bar=True))
    return embeddings
| 72 | + |
| 73 | + |
def run(job_input: IJobInput):
    """
    Data job step: load documents from CSV, clean them, embed them, and
    pickle the resulting embeddings into the job directory.

    :param job_input: VDK job input handle providing the job directory and
        a temporary writable directory for nltk corpora
    """
    # hoisted from the middle of the function: stdlib import belongs at the
    # top, not inside the `with` block
    import pickle

    log.info(f"Starting job step {__name__}")

    input_csv = DOCUMENTS_CSV_FILE_LOCATION
    data_job_dir = pathlib.Path(job_input.get_job_directory())
    output_embeddings = data_job_dir / EMBEDDINGS_PKL_FILE_LOCATION

    # nltk needs a writable location for its corpora; use a temporary
    # (until the end of the job execution) directory with write permissions
    temp_dir = job_input.get_temporary_write_directory()
    nltk_data_path = temp_dir / "nltk_data"
    nltk_data_path.mkdir(exist_ok=True)
    nltk.data.path.append(str(nltk_data_path))

    nltk.download("stopwords", download_dir=str(nltk_data_path))
    nltk.download("wordnet", download_dir=str(nltk_data_path))

    cleaned_documents = load_and_clean_documents(input_csv)
    if cleaned_documents:
        log.info(
            f"{len(cleaned_documents)} documents loaded and cleaned for embedding."
        )
        embeddings = embed_documents_in_batches(cleaned_documents)
        with open(output_embeddings, "wb") as file:
            pickle.dump(embeddings, file)
        log.info(f"Embeddings saved to {output_embeddings}")
0 commit comments