3 changes: 3 additions & 0 deletions .github/code_spell_ignore.txt
@@ -0,0 +1,3 @@
rouge
Rouge
ROUGE
32 changes: 32 additions & 0 deletions evals/evaluation/rag_eval/README.md
@@ -0,0 +1,32 @@
# CRUD-RAG
[CRUD-RAG](https://arxiv.org/abs/2401.17043) is a Chinese benchmark for RAG (Retrieval-Augmented Generation) systems. This example uses CRUD-RAG to evaluate a RAG system.

## Prerequisites

### Environment
```bash
pip install -r requirements.txt
```

### Prepare Dataset
We use the evaluation dataset from the [CRUD-RAG](https://github.com/IAAR-Shanghai/CRUD_RAG) repo. Use the commands below to prepare it.
```bash
git clone https://github.com/IAAR-Shanghai/CRUD_RAG
mkdir data/
cp CRUD_RAG/data/crud_split/split_merged.json data/
cp -r CRUD_RAG/data/80000_docs/ data/
python examples/process_crud_dataset.py
```
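
After these steps, `data/` should contain the merged evaluation split and the document corpus, roughly as sketched below (the exact document files depend on what `examples/process_crud_dataset.py` produces):
```
data/
├── split_merged.json   # evaluation samples for all CRUD-RAG tasks
└── 80000_docs/         # news documents to be ingested into the vector database
```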

### Launch Service of RAG System
Please refer to this [guide](https://github.com/opea-project/GenAIExamples/blob/main/ChatQnA/README.md) to launch the RAG system service.

## Evaluation
Use the command below to run the evaluation. On the first run, pass `--ingest_docs` to ingest the documents into the vector database; omit it on subsequent runs. Use `--tasks` to select the tasks to evaluate (`question_answering` and `summarization` are currently supported by `examples/main.py`).
```bash
cd examples
python main.py --dataset_path ../data/split_merged.json --docs_path ../data/80000_docs --ingest_docs
```
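
Results are written under `--output_dir` (default `./output`), one JSON file per task. A rough sketch of the saved structure, with illustrative values rather than real results:
```json
{
    "overall": {
        "avg. bleu-avg": 0.0,
        "avg. bleu-1": 0.0,
        "avg. rouge-L": 0.0,
        "avg. LLM-score": 0.0,
        "avg. length": 0.0,
        "num": 1
    },
    "results": [
        {
            "id": "...",
            "metrics": {"bleu-avg": 0.0, "bleu-1": 0.0, "bleu-2": 0.0, "bleu-3": 0.0, "bleu-4": 0.0, "rouge-L": 0.0, "LLM-score": 0.0, "length": 0},
            "log": {"generated_text": "...", "ground_truth_text": "...", "evaluateDatetime": "..."},
            "valid": true
        }
    ]
}
```

If an output file already exists, the run resumes from it: samples with valid saved results are skipped.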

## Acknowledgements
This example is mostly adapted from the [CRUD-RAG](https://github.com/IAAR-Shanghai/CRUD_RAG) repo; we thank the authors for their great work!
10 changes: 10 additions & 0 deletions evals/evaluation/rag_eval/__init__.py
@@ -0,0 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .evaluator import Evaluator

__all__ = ["Evaluator"]
215 changes: 215 additions & 0 deletions evals/evaluation/rag_eval/evaluator.py
@@ -0,0 +1,215 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import datetime
import json
import os

import requests
from tqdm import tqdm

from evals.metrics import bleu_score, rougeL_score

from .metrics import LLM_score


class Evaluator:
    def __init__(
        self,
        dataset: list[dict],
        output_path: str,
        task: str,
    ) -> None:
        """Args:
        dataset (list[dict]): The dataset for evaluation.
        output_path (str): The path to save results.
        task (str): Task to evaluate.
        """
        self.task = task
        self.output_path = output_path
        self.dataset = dataset

    @staticmethod
    def ingest_docs(documents_path: str, database_endpoint: str):
        """Args:
        documents_path (str): The path to documents.
        database_endpoint (str): URL of the database.
        """
        files = []
        if os.path.isfile(documents_path):
            files.append(documents_path)
        elif os.path.isdir(documents_path):
            for root, _, files_ in os.walk(documents_path):
                files += [os.path.join(root, f) for f in files_]
        for file in tqdm(files):
            # Use a context manager so the file handle is closed even if the request fails.
            with open(file, mode="rb") as file_obj:
                response = requests.post(database_endpoint, files={"files": file_obj})
            if response.ok:
                print(f"Successfully ingested {file}.")
            else:
                print(f"Failed to ingest {file}.")

    def get_ground_truth_text(self, data: dict):
        raise NotImplementedError("Depends on the specific dataset.")

    def get_query(self, data: dict):
        raise NotImplementedError("Depends on the specific dataset.")

    def get_document(self, data: dict):
        raise NotImplementedError("Depends on the specific dataset.")

    def scoring(self, data: dict, llm_endpoint: str = None) -> dict:
        """Score a single sample against its ground truth with BLEU, ROUGE-L and an LLM judge."""
        generated_text = data["generated_text"]
        ground_truth_text = self.get_ground_truth_text(data)
        data["ground_truth_text"] = ground_truth_text

        bleu_avg, bleu1, bleu2, bleu3, bleu4 = bleu_score(generated_text, ground_truth_text)

        return {
            "metrics": {
                "bleu-avg": bleu_avg or 0.0,
                "bleu-1": bleu1 or 0.0,
                "bleu-2": bleu2 or 0.0,
                "bleu-3": bleu3 or 0.0,
                "bleu-4": bleu4 or 0.0,
                "rouge-L": rougeL_score(generated_text, ground_truth_text) or 0.0,
                "LLM-score": LLM_score(generated_text, ground_truth_text, llm_endpoint) or 0.0,
                "length": len(generated_text),
            },
            "log": {
                "generated_text": generated_text,
                "ground_truth_text": ground_truth_text,
                "evaluateDatetime": str(datetime.datetime.now()),
            },
            "valid": len(generated_text.strip()) != 0,
        }

    def compute_overall(self, results: list[dict]) -> dict:
        """Average the per-sample metrics over all valid results."""
        overall = {
            "bleu-avg": 0,
            "bleu-1": 0,
            "bleu-2": 0,
            "bleu-3": 0,
            "bleu-4": 0,
            "rouge-L": 0,
            "LLM-score": 0.0,
            "length": 0,
        }

        for result in results:
            overall = {key: overall[key] + result["metrics"][key] for key in overall.keys()}

        overall_save = {f"avg. {key}": value / len(results) for key, value in overall.items()}

        overall_save["num"] = len(results)

        return overall_save

    def save_output(self, output: dict) -> None:
        """Save evaluation results."""
        with open(self.output_path, "w", encoding="utf-8") as f:
            json.dump(output, f, ensure_ascii=False, indent=4)

    def read_output(self) -> dict:
        with open(self.output_path) as f:
            return json.load(f)

    def remove_invalid(self, results: list[dict]) -> list[dict]:
        """Remove invalid results from the list and return the cleaned results."""
        return [result for result in results if result["valid"]]

    def send_request(self, data, arguments):
        """Query the RAG pipeline with an OpenAI-style chat payload and return the generated text."""
        service_url = arguments.service_url
        headers = {"Content-Type": "application/json"}
        json_data = {}
        query = self.get_query(data)
        json_data["messages"] = query
        json_data["stream"] = False
        json_data["temperature"] = arguments.temperature
        json_data["max_new_tokens"] = arguments.max_new_tokens
        json_data = json.dumps(json_data)
        response = requests.post(service_url, data=json_data, headers=headers)
        if response.ok:
            return response.json()["choices"][0]["message"]["content"]
        else:
            print(f"Request for pipeline failed due to {response.text}.")
            return ""

    def get_retrieved_documents(self, data, arguments):
        """Two-step retrieval: embed the query, then fetch similar documents from the retriever."""
        query = self.get_query(data)
        data = {"text": query}
        headers = {"Content-Type": "application/json"}
        response = requests.post(arguments.embedding_endpoint, data=json.dumps(data), headers=headers)
        if response.ok:
            embedding = response.json()["embedding"]
        else:
            print(f"Request for embedding failed due to {response.text}.")
            return []
        data = {
            "text": query,
            "embedding": embedding,
            "search_type": "similarity",
            "k": 4,
            "fetch_k": 20,
            "lambda_mult": 0.5,
        }
        response = requests.post(arguments.retrieval_endpoint, data=json.dumps(data), headers=headers)
        if response.ok:
            retrieved_documents = response.json()["retrieved_docs"]
            return [doc["text"] for doc in retrieved_documents]
        else:
            print(f"Request for retrieval failed due to {response.text}.")
            return []

    def scoring_retrieval(self, data, retrieved_documents):
        # Placeholder: retrieval metrics are not implemented yet.
        ground_truth_documents = self.get_document(data)

    def evaluate(self, arguments, sort=True, show_progress_bar=False, contain_original_data=False):
        """Run a complete evaluation.

        Args:
            arguments: Arguments.
            sort (bool): Whether to sort the results by id.
            show_progress_bar (bool): Whether to display a progress bar.
            contain_original_data (bool): Whether to include original data in the results for debugging.

        Returns:
            dict: Output dictionary with fields such as overall and results.
        """
        if os.path.exists(self.output_path):  # Resume evaluation
            results = self.read_output().get("results", [])
            results = self.remove_invalid(results)
            saved_ids = [result["id"] for result in results]
        else:
            results = []
            saved_ids = []

        for data in tqdm(self.dataset) if show_progress_bar else self.dataset:
            if data["ID"] in saved_ids:
                continue  # Skip results that have already been evaluated and are valid
            try:
                retrieved_documents = self.get_retrieved_documents(data, arguments)
                data["retrieved_documents"] = retrieved_documents
                generated_text = self.send_request(data, arguments)
                data["generated_text"] = generated_text
                result = {"id": data["ID"], **self.scoring(data, arguments.llm_endpoint)}
                if contain_original_data:
                    result["original_data"] = data
                results.append(result)
            except Exception as e:
                print(repr(e))

        results = sorted(results, key=lambda x: x["id"]) if sort else results
        valid_results = self.remove_invalid(results)

        try:
            overall = self.compute_overall(valid_results) if len(valid_results) > 0 else {}
        except Exception as e:
            print(repr(e))
            overall = dict()

        output = {"overall": overall, "results": results}
        self.save_output(output)
        print(f"Output saved to {self.output_path}!")
        return output
133 changes: 133 additions & 0 deletions evals/evaluation/rag_eval/examples/main.py
@@ -0,0 +1,133 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import argparse
import json
import os

from evals.evaluation.rag_eval import Evaluator


class CRUD_Evaluator(Evaluator):
    """Maps each CRUD-RAG task to the dataset fields holding its query, document and ground truth."""

    def get_ground_truth_text(self, data: dict):
        if self.task == "summarization":
            ground_truth_text = data["summary"]
        elif self.task == "question_answering":
            ground_truth_text = data["answers"]
        elif self.task == "continuation":
            ground_truth_text = data["continuing"]
        elif self.task == "hallucinated_modified":
            ground_truth_text = data["hallucinatedMod"]
        else:
            raise NotImplementedError(
                f"Unknown task {self.task}, only support "
                "summarization, question_answering, continuation and hallucinated_modified."
            )
        return ground_truth_text

    def get_query(self, data: dict):
        if self.task == "summarization":
            query = data["text"]
        elif self.task == "question_answering":
            query = data["questions"]
        elif self.task == "continuation":
            query = data["beginning"]
        elif self.task == "hallucinated_modified":
            query = data["newsBeginning"]
        else:
            raise NotImplementedError(
                f"Unknown task {self.task}, only support "
                "summarization, question_answering, continuation and hallucinated_modified."
            )
        return query

    def get_document(self, data: dict):
        if self.task == "summarization":
            document = data["text"]
        elif self.task == "question_answering":
            document = data["news1"]
        elif self.task == "continuation":
            document = data["beginning"]
        elif self.task == "hallucinated_modified":
            document = data["newsBeginning"]
        else:
            raise NotImplementedError(
                f"Unknown task {self.task}, only support "
                "summarization, question_answering, continuation and hallucinated_modified."
            )
        return document


def args_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--service_url", type=str, default="http://localhost:8888/v1/chatqna", help="URL of the RAG pipeline service."
    )
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save evaluation results.")
    parser.add_argument(
        "--temperature", type=float, default=0.1, help="Controls the randomness of the model's text generation"
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=1280, help="Maximum number of new tokens to be generated by the model"
    )
    parser.add_argument("--dataset_path", default="../data/split_merged.json", help="Path to the dataset")
    parser.add_argument("--docs_path", default="../data/80000_docs", help="Path to the retrieval documents")

    # Retriever related options
    parser.add_argument("--tasks", default=["question_answering"], nargs="+", help="Tasks to evaluate")
    parser.add_argument(
        "--ingest_docs", action="store_true", help="Ingest the documents into the vector database before evaluating"
    )
    parser.add_argument(
        "--database_endpoint", type=str, default="http://localhost:6007/v1/dataprep", help="URL of the dataprep service."
    )
    parser.add_argument(
        "--embedding_endpoint", type=str, default="http://localhost:6000/v1/embeddings", help="URL of the embedding service."
    )
    parser.add_argument(
        "--retrieval_endpoint", type=str, default="http://localhost:7000/v1/retrieval", help="URL of the retrieval service."
    )
    parser.add_argument(
        "--llm_endpoint", type=str, default="http://localhost:9009/generate", help="URL of the LLM generation endpoint."
    )
    # BooleanOptionalAction avoids the argparse pitfall where type=bool treats any non-empty string as True.
    parser.add_argument(
        "--show_progress_bar", action=argparse.BooleanOptionalAction, default=True, help="Whether to show a progress bar"
    )
    parser.add_argument(
        "--contain_original_data", action="store_true", help="Whether to include the original data in saved results"
    )

    args = parser.parse_args()
    return args


def main():
    args = args_parser()
    if os.path.isfile(args.dataset_path):
        with open(args.dataset_path) as f:
            all_datasets = json.load(f)
    else:
        raise FileNotFoundError(f"Evaluation dataset file {args.dataset_path} does not exist.")
    os.makedirs(args.output_dir, exist_ok=True)
    for task in args.tasks:
        if task == "question_answering":
            dataset = all_datasets["questanswer_1doc"]
        elif task == "summarization":
            dataset = all_datasets["event_summary"]
        else:
            raise NotImplementedError(
                f"Unknown task {task}, only question_answering and summarization are supported in this example."
            )
        output_save_path = os.path.join(args.output_dir, f"{task}.json")
        evaluator = CRUD_Evaluator(dataset, output_save_path, task)
        if args.ingest_docs:
            CRUD_Evaluator.ingest_docs(args.docs_path, args.database_endpoint)
        results = evaluator.evaluate(
            args, show_progress_bar=args.show_progress_bar, contain_original_data=args.contain_original_data
        )
        print(f"Evaluation results of task {task} saved to {output_save_path}.")


if __name__ == "__main__":
    main()