From 449b400a063d53b034db84f7e19e09ba61d98899 Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Wed, 27 Dec 2023 15:55:40 +0800
Subject: [PATCH 01/22] zhishu completion function

---
 evals/completion_fns/zhishu.py            | 110 ++++++++++++++++++++++
 evals/registry/completion_fns/zhishu.yaml |   9 ++
 2 files changed, 119 insertions(+)
 create mode 100644 evals/completion_fns/zhishu.py
 create mode 100644 evals/registry/completion_fns/zhishu.yaml

diff --git a/evals/completion_fns/zhishu.py b/evals/completion_fns/zhishu.py
new file mode 100644
index 0000000000..fda467de1b
--- /dev/null
+++ b/evals/completion_fns/zhishu.py
@@ -0,0 +1,110 @@
+from typing import Any, Optional, Union
+import os
+import requests
+
+from openai import OpenAI
+
+from evals.api import CompletionFn, CompletionResult
+from evals.base import CompletionFnSpec
+from evals.prompt.base import (
+    ChatCompletionPrompt,
+    CompletionPrompt,
+    OpenAICreateChatPrompt,
+    OpenAICreatePrompt,
+    Prompt,
+)
+from evals.record import record_sampling
+from evals.utils.api_utils import (
+    request_with_timeout
+)
+
+default_prompts = {
+    "activity": "请汇总文献中全部抑制剂(分子请分别用名字和SMILES表达)的结合活性、活性种类(IC50, EC50, TC50, Ki, Kd中的一个),并备注每类结合活性的实验手段。以json格式输出,活性和活性类型的字段名分别为 \"Affinity\" 和 \"Affinity_type\"",
+}
+
+
+class Struct:
+    def __init__(self, **entries):
+        self.__dict__.update({k: self._wrap(v) for k, v in entries.items()})
+
+    def _wrap(self, value):
+        if isinstance(value, (tuple, list, set, frozenset)):
+            return type(value)([self._wrap(v) for v in value])
+        else:
+            return Struct(**value) if isinstance(value, dict) else value
+
+    def __repr__(self):
+        return '<%s>' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items()))
+
+
+class BaseCompletionResult(CompletionResult):
+    def __init__(self, raw_data: Any, prompt: Any):
+        self.raw_data = Struct(raw_data) if type(raw_data) == dict else raw_data
+        self.prompt = prompt
+
+    def get_completions(self) -> list[str]:
+        raise NotImplementedError
+
+
+class ZhishuCompletionResult(BaseCompletionResult):
+    def get_completions(self) -> list[str]:
+        completions = []
+        if self.raw_data:
+            for choice in self.raw_data.choices:
+                if choice.message.content is not None:
+                    completions.append(choice.message.content)
+        return completions
+
+
+class ZhishuCompletionFn(CompletionFn):
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_key: Optional[str] = None,
+        n_ctx: Optional[int] = None,
+        extra_options: Optional[dict] = {},
+        **kwargs,
+    ):
+        self.model = model
+        self.api_base = api_base
+        self.api_key = api_key
+        self.n_ctx = n_ctx
+        self.extra_options = extra_options
+
+    def __call__(
+        self,
+        prompt: Union[str, OpenAICreateChatPrompt],
+        **kwargs,
+    ) -> ZhishuCompletionResult:
+        if not isinstance(prompt, Prompt):
+            assert (
+                isinstance(prompt, str)
+                or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt))
+                or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt))
+                or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt))
+            ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]"
+
+            prompt = CompletionPrompt(
+                raw_prompt=prompt,
+            )
+
+        openai_create_prompt: OpenAICreatePrompt = prompt.to_formatted_prompt()
+
+        url = f"https://api.zhishuyun.com/openai/gpt-4-all?token={self.api_key or os.environ['ZHISHU_API_KEY']}"
+        headers = {
+            "content-type": "application/json"
+        }
+        payload = {
+            "model": 
"gpt-4-all", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": f"{kwargs['file_link']} {prompt}"} + ] + } + + result = request_with_timeout(requests.post, url, json=payload, headers=headers) + + result = ZhishuCompletionResult(raw_data=result.json(), prompt=openai_create_prompt) + record_sampling(prompt=result.prompt, sampled=result.get_completions()) + return result diff --git a/evals/registry/completion_fns/zhishu.yaml b/evals/registry/completion_fns/zhishu.yaml new file mode 100644 index 0000000000..0a5d712b79 --- /dev/null +++ b/evals/registry/completion_fns/zhishu.yaml @@ -0,0 +1,9 @@ +zhishu/gpt-3.5-turbo: + class: evals.completion_fns.zhishu:ZhishuCompletionFn + args: + model: gpt-3.5-turbo + +zhishu/gpt-4.0-turbo: + class: evals.completion_fns.zhishu:ZhishuCompletionFn + args: + model: gpt-4.0-turbo From 25e80ac4aa367f47777735a3e56291d07f587727 Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Wed, 27 Dec 2023 15:55:59 +0800 Subject: [PATCH 02/22] trial implementation of table_extract tasks --- evals/elsuite/table_extract.py | 183 ++++++++++++++++++ .../registry/evals/00_scipaper_affinity.yaml | 24 +++ 2 files changed, 207 insertions(+) create mode 100644 evals/elsuite/table_extract.py create mode 100644 evals/registry/evals/00_scipaper_affinity.yaml diff --git a/evals/elsuite/table_extract.py b/evals/elsuite/table_extract.py new file mode 100644 index 0000000000..f18b5ac92a --- /dev/null +++ b/evals/elsuite/table_extract.py @@ -0,0 +1,183 @@ +from io import StringIO +import json +import os +import re + +from typing import List, Optional + +import oss2 +from oss2.credentials import EnvironmentVariableCredentialsProvider + +from urllib.parse import parse_qs, urlparse + +from datasets import load_dataset +import pandas as pd +from pydantic import BaseModel + +import evals +import evals.metrics +from evals.api import CompletionFn +from evals.formatting import make_abc +from evals.record import RecorderBase, record_match + +code_pattern = r"```[\s\S]*?\n([\s\S]+)\n```" +json_pattern = r"```json[\s\S]*?\n([\s\S]+)\n```" +csv_pattern = r"```csv[\s\S]*?\n([\s\S]+)\n```" + + +def parse_table_multiindex(table: pd.DataFrame) -> pd.DataFrame: + """ + Parse a table with multiindex columns. + """ + + df = table.copy() + coltypes = {col: type(df[col].iloc[0]) for col in df.columns} + for col, ctype in coltypes.items(): + if ctype == str: + if ":" in df[col].iloc[0] and "," in df[col].iloc[0]: + df[col] = [{key: value for key, value in [pair.split(": ") for pair in data.split(", ")]} for data in + df[col]] + coltypes[col] = dict + dfs = [] + + for col, ctype in coltypes.items(): + if ctype == dict: + d = pd.DataFrame(df.pop(col).tolist()) + dfs.append(d) + df = pd.concat([df] + dfs, axis=1) + return df + + +def init_oss(): + """ + Initialize OSS client. + """ + # Please set OSS_ACCESS_KEY_ID & OSS_ACCESS_KEY_SECRET in your environment variables. 
+    auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())
+
+    # 设置 Endpoint
+    endpoint = 'https://oss-cn-beijing.aliyuncs.com'
+
+    # 设置 Bucket
+    bucket_name = 'dp-filetrans-bj'
+    bucket = oss2.Bucket(auth, endpoint, bucket_name)
+
+    return bucket
+
+
+class FileSample(BaseModel):
+    file_name: Optional[str]
+    file_link: Optional[str]
+    question: Optional[str]
+    answerfile_name: Optional[str]
+    answerfile_link: Optional[str]
+    compare_fields: List[str]
+
+
+def get_dataset(data_jsonl: str) -> list[FileSample]:
+    bucket = init_oss()
+    raw_samples = evals.get_jsonl(data_jsonl)
+
+    for raw_sample in raw_samples:
+        if "file_name" in raw_sample:
+            oss_file = "changjunhan/" + os.path.basename(raw_sample["file_name"])
+            raw_sample["file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file
+
+            exists = bucket.object_exists(oss_file)
+            if exists:
+                print(f"文件 {oss_file} 已存在于 OSS 中。")
+            else:
+                # 上传文件
+                bucket.put_object_from_file(oss_file, raw_sample["file_name"])
+                print(f"文件 {oss_file} 已上传到 OSS。")
+        elif "file_link" in raw_sample:
+            local_file = raw_sample["file_name"] if "file_name" in raw_sample else os.path.basename(
+                raw_sample["file_link"])
+            oss_file = "changjunhan/" + os.path.basename(raw_sample["file_link"])
+            if not os.path.exists(local_file):
+                if bucket.object_exists(oss_file):
+                    # 从 OSS 下载文件
+                    bucket.get_object_to_file(oss_file, local_file)
+
+    samples = [FileSample(**raw_sample) for raw_sample in raw_samples]
+    return samples
+
+
+class TableExtract(evals.Eval):
+    def __init__(
+        self,
+        completion_fns: list[CompletionFn],
+        dataset: str,
+        *args,
+        instructions: Optional[str] = "",
+        **kwargs,
+    ):
+        super().__init__(completion_fns, *args, **kwargs)
+        assert len(completion_fns) < 3, "TableExtract supports at most 2 completion fns"
+        self.dataset = dataset
+        self.instructions = instructions
+
+    def eval_sample(self, sample, rng):
+        assert isinstance(sample, FileSample)
+
+        prompt = (
+            self.instructions
+            + "\nPlease answer in json format."
+ + f"\nThe fields should at least contain {sample.compare_fields}" + ) + result = self.completion_fn( + prompt=prompt, + temperature=0.0, + max_tokens=5, + file_name=sample.file_name, + ) + sampled = result.get_completions()[0] + + if "csv" in prompt: + code = re.search(code_pattern, sampled).group() + code_content = re.sub(code_pattern, r"\1", code) + table = pd.read_csv(StringIO(code_content)) + elif "json" in prompt: + code = re.search(code_pattern, sampled).group() + code_content = re.sub(code_pattern, r"\1", code) + table = pd.DataFrame(json.loads(code_content)) + else: + table = pd.DataFrame() + table = parse_table_multiindex(table) + + correct_answer = pd.read_csv(sample.answerfile) + + for field in sample.compare_fields: + match_field = field in table.columns and field in correct_answer.columns + record_match( + correct=match_field, + expected=field, + picked=str(list(table.columns)), + file_name=sample.file_name, + jobtype="match_field" + ) + if match_field: + match_number = table[field].shape[0] == correct_answer[field].shape[0] + record_match( + correct=match_number, + expected=correct_answer[field].shape[0], + picked=table[field].shape[0], + file_name=sample.file_name, + jobtype="match_number" + ) + + for sample_value, correct_value in zip(table[field], correct_answer[field]): + record_match( + correct=(sample_value == correct_value), + expected=correct_value, + picked=sample_value, + file_name=sample.file_name, + jobtype="match_value" + ) + + def run(self, recorder: RecorderBase): + samples = get_dataset(self.dataset) + self.eval_all_samples(recorder, samples) + return { + "accuracy": evals.metrics.get_accuracy(recorder.get_events("match")), + } diff --git a/evals/registry/evals/00_scipaper_affinity.yaml b/evals/registry/evals/00_scipaper_affinity.yaml new file mode 100644 index 0000000000..7a48117aa0 --- /dev/null +++ b/evals/registry/evals/00_scipaper_affinity.yaml @@ -0,0 +1,24 @@ +scipaper_affinity: + id: scipaper_affinity.val.ab-v1 + metrics: [accuracy] +scipaper_affinity.val.ab-v1: + class: evals.elsuite.table_extract:TableExtract + args: + dataset: /Users/chang/AI_projects/Uni-finder/assays/0003001_0004000/answer/data.jsonl + instructions: | + Please summarize the names, SMILES, affinities, target info (protein or cell line), and affinity types (chosen from IC50, EC50, TC50, GI50, Ki, Kd) of all inhibitors in the paper, and output in json format. 
For example: + ```json + [ + { + "Compound": "5a", + "Name": "d4a", + "SMILES": "Unknown", + "Affinities": { + "5HT1A (IC50)": "2.0 nM", + "5HT1D (IC50)": "8.0 nM", + "5HT-UT (IC50)": "12.6 nM", + "5HT1E (IC50)": ">1000 nM" + } + } + ] + ``` From 4d2e2406f32684a54bac41c6f1093cdde94610bf Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Wed, 27 Dec 2023 21:42:06 +0800 Subject: [PATCH 03/22] bugfixes and add retrieve_native completion_fn --- evals/completion_fns/retrieval_native.py | 97 +++++++++++++++++++++ evals/completion_fns/zhishu.py | 21 ++--- evals/elsuite/table_extract.py | 46 +++++----- evals/registry/completion_fns/retrieve.yaml | 23 +++++ 4 files changed, 149 insertions(+), 38 deletions(-) create mode 100644 evals/completion_fns/retrieval_native.py create mode 100644 evals/registry/completion_fns/retrieve.yaml diff --git a/evals/completion_fns/retrieval_native.py b/evals/completion_fns/retrieval_native.py new file mode 100644 index 0000000000..eb5b801816 --- /dev/null +++ b/evals/completion_fns/retrieval_native.py @@ -0,0 +1,97 @@ +""" +Extending Completion Functions with Embeddings-based retrieval from a fetched dataset +""" +import os +from ast import literal_eval +import time +from typing import Any, Optional, Union + +import numpy as np +from openai import OpenAI + +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) +import pandas as pd + +from evals.api import CompletionFn, CompletionResult +from evals.prompt.base import ChatCompletionPrompt, CompletionPrompt +from evals.record import record_sampling + + +class RetrievalCompletionResult(CompletionResult): + def __init__(self, response: str) -> None: + self.response = response + + def get_completions(self) -> list[str]: + return [self.response.strip()] + + +class OpenAIRetrievalCompletionFn(CompletionFn): + """ + This Completion Function uses embeddings to retrieve the top k relevant docs from a dataset to the prompt, then adds them to the context before calling the completion. + """ + + def __init__( + self, + model: Optional[str] = None, + instructions: Optional[str] = "You are a helpful assistant on extracting information from files.", + api_base: Optional[str] = None, + api_key: Optional[str] = None, + n_ctx: Optional[int] = None, + extra_options: Optional[dict] = {}, + **kwargs + ): + self.model = model + self.instructions = instructions + self.api_base = api_base + self.api_key = api_key + self.n_ctx = n_ctx + self.extra_options = extra_options + + def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> RetrievalCompletionResult: + """ + Args: + prompt: The prompt to complete, in either text string or Chat format. + kwargs: Additional arguments to pass to the completion function call method. + """ + + assert "file_name" in kwargs, "Must provide a file_name to retrieve." 
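+        # Flow used below: upload the file, create a one-off assistant with the
+        # "retrieval" tool attached, start a thread and run, then poll until the
+        # run reports completion.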
+ + file = client.files.create(file=open(kwargs["file_name"], "rb"), purpose='assistants') + + # Create an Assistant (Note model="gpt-3.5-turbo-1106" instead of "gpt-4-1106-preview") + assistant = client.beta.assistants.create( + name="File Assistant", + instructions=self.instructions, + model=self.model, + tools=[{"type": "retrieval"}], + file_ids=[file.id] + ) + + # Create a Thread + thread = client.beta.threads.create() + + # Add a Message to a Thread + print(prompt) + message = client.beta.threads.messages.create(thread_id=thread.id, role="user", + content=prompt + ) + + # Run the Assistant + run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant.id) + print(run.model_dump_json(indent=4)) + + # If run is 'completed', get messages and print + while True: + # Retrieve the run status + run_status = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) + time.sleep(10) + if run_status.status == 'completed': + messages = client.beta.threads.messages.list(thread_id=thread.id) + answer = messages.data[0].content[0].text.value + break + else: + ### sleep again + time.sleep(2) + print(answer) + record_sampling(prompt=prompt, sampled=answer) + return RetrievalCompletionResult(answer) diff --git a/evals/completion_fns/zhishu.py b/evals/completion_fns/zhishu.py index fda467de1b..41e1ff97e0 100644 --- a/evals/completion_fns/zhishu.py +++ b/evals/completion_fns/zhishu.py @@ -2,13 +2,8 @@ import os import requests -from openai import OpenAI - from evals.api import CompletionFn, CompletionResult -from evals.base import CompletionFnSpec from evals.prompt.base import ( - ChatCompletionPrompt, - CompletionPrompt, OpenAICreateChatPrompt, OpenAICreatePrompt, Prompt, @@ -39,7 +34,7 @@ def __repr__(self): class BaseCompletionResult(CompletionResult): def __init__(self, raw_data: Any, prompt: Any): - self.raw_data = Struct(raw_data) if type(raw_data) == dict else raw_data + self.raw_data = Struct(**raw_data) if type(raw_data) == dict else raw_data self.prompt = prompt def get_completions(self) -> list[str]: @@ -60,6 +55,7 @@ class ZhishuCompletionFn(CompletionFn): def __init__( self, model: Optional[str] = None, + instructions: Optional[str] = "You are a helpful assistant on extracting information from files.", api_base: Optional[str] = None, api_key: Optional[str] = None, n_ctx: Optional[int] = None, @@ -67,6 +63,7 @@ def __init__( **kwargs, ): self.model = model + self.instructions = instructions self.api_base = api_base self.api_key = api_key self.n_ctx = n_ctx @@ -85,26 +82,20 @@ def __call__( or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt)) ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]" - prompt = CompletionPrompt( - raw_prompt=prompt, - ) - - openai_create_prompt: OpenAICreatePrompt = prompt.to_formatted_prompt() - url = f"https://api.zhishuyun.com/openai/gpt-4-all?token={self.api_key or os.environ['ZHISHU_API_KEY']}" headers = { "content-type": "application/json" } payload = { - "model": "gpt-4-all", + "model": self.model, "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": self.instructions}, {"role": "user", "content": f"{kwargs['file_link']} {prompt}"} ] } result = request_with_timeout(requests.post, url, json=payload, headers=headers) - result = ZhishuCompletionResult(raw_data=result.json(), prompt=openai_create_prompt) + result = ZhishuCompletionResult(raw_data=result.json(), 
prompt=prompt) record_sampling(prompt=result.prompt, sampled=result.get_completions()) return result diff --git a/evals/elsuite/table_extract.py b/evals/elsuite/table_extract.py index f18b5ac92a..ebe5fa8c97 100644 --- a/evals/elsuite/table_extract.py +++ b/evals/elsuite/table_extract.py @@ -68,7 +68,6 @@ def init_oss(): class FileSample(BaseModel): file_name: Optional[str] file_link: Optional[str] - question: Optional[str] answerfile_name: Optional[str] answerfile_link: Optional[str] compare_fields: List[str] @@ -79,26 +78,27 @@ def get_dataset(data_jsonl: str) -> list[FileSample]: raw_samples = evals.get_jsonl(data_jsonl) for raw_sample in raw_samples: - if "file_name" in raw_sample: - oss_file = "changjunhan/" + os.path.basename(raw_sample["file_name"]) - raw_sample["file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file - - exists = bucket.object_exists(oss_file) - if exists: - print(f"文件 {oss_file} 已存在于 OSS 中。") - else: - # 上传文件 - bucket.put_object_from_file(oss_file, raw_sample["file_name"]) - print(f"文件 {oss_file} 已上传到 OSS。") - elif "file_link" in raw_sample: - local_file = raw_sample["file_name"] if "file_name" in raw_sample else os.path.basename( - raw_sample["file_link"]) - oss_file = "changjunhan/" + os.path.basename(raw_sample["file_link"]) - if not os.path.exists(local_file): - if bucket.object_exists(oss_file): - # 从 OSS 下载文件 - bucket.get_object_to_file(oss_file, local_file) - + for ftype in ["", "answer"]: + if f"{ftype}file_name" in raw_sample: + oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_name"]) + raw_sample[f"{ftype}file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file + + exists = bucket.object_exists(oss_file) + if exists: + print(f"文件 {oss_file} 已存在于 OSS 中。") + else: + # 上传文件 + bucket.put_object_from_file(oss_file, raw_sample[f"{ftype}file_name"]) + print(f"文件 {oss_file} 已上传到 OSS。") + elif f"{ftype}file_link" in raw_sample: + local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else os.path.basename( + raw_sample[f"{ftype}file_link"]) + oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_link"]) + if not os.path.exists(local_file): + if bucket.object_exists(oss_file): + # 从 OSS 下载文件 + bucket.get_object_to_file(oss_file, local_file) + print(raw_samples) samples = [FileSample(**raw_sample) for raw_sample in raw_samples] return samples @@ -122,7 +122,6 @@ def eval_sample(self, sample, rng): prompt = ( self.instructions - + "\nPlease answer in json format." 
+ f"\nThe fields should at least contain {sample.compare_fields}" ) result = self.completion_fn( @@ -130,6 +129,7 @@ def eval_sample(self, sample, rng): temperature=0.0, max_tokens=5, file_name=sample.file_name, + file_link=sample.file_link ) sampled = result.get_completions()[0] @@ -145,7 +145,7 @@ def eval_sample(self, sample, rng): table = pd.DataFrame() table = parse_table_multiindex(table) - correct_answer = pd.read_csv(sample.answerfile) + correct_answer = pd.read_csv(sample.answerfile_name, header=[0, 1]) for field in sample.compare_fields: match_field = field in table.columns and field in correct_answer.columns diff --git a/evals/registry/completion_fns/retrieve.yaml b/evals/registry/completion_fns/retrieve.yaml new file mode 100644 index 0000000000..648d58a5f9 --- /dev/null +++ b/evals/registry/completion_fns/retrieve.yaml @@ -0,0 +1,23 @@ +retrieval/presidents/gpt-3.5-turbo: + class: evals.completion_fns.retrieval:RetrievalCompletionFn + args: + completion_fn: gpt-3.5-turbo + embeddings_and_text_path: presidents_embeddings.csv + k: 2 + +retrieval/presidents/cot/gpt-3.5-turbo: + class: evals.completion_fns.retrieval:RetrievalCompletionFn + args: + completion_fn: cot/gpt-3.5-turbo + embeddings_and_text_path: presidents_embeddings.csv + k: 2 + +retrieval_native/gpt-3.5-turbo: + class: evals.completion_fns.retrieval_native:OpenAIRetrievalCompletionFn + args: + model: gpt-3.5-turbo-1106 + +retrieval_native/gpt-4-all: + class: evals.completion_fns.retrieval_native:OpenAIRetrievalCompletionFn + args: + model: gpt-4-1106-preview \ No newline at end of file From 686ecd4371eda93f594c695222aa46b11d3f4a22 Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Wed, 27 Dec 2023 22:58:06 +0800 Subject: [PATCH 04/22] add fuzzy_compare for table content --- evals/elsuite/table_extract.py | 49 ++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/evals/elsuite/table_extract.py b/evals/elsuite/table_extract.py index ebe5fa8c97..0f788bf684 100644 --- a/evals/elsuite/table_extract.py +++ b/evals/elsuite/table_extract.py @@ -3,21 +3,17 @@ import os import re -from typing import List, Optional +from typing import List, Optional, Tuple, Union import oss2 from oss2.credentials import EnvironmentVariableCredentialsProvider -from urllib.parse import parse_qs, urlparse - -from datasets import load_dataset import pandas as pd from pydantic import BaseModel import evals import evals.metrics from evals.api import CompletionFn -from evals.formatting import make_abc from evals.record import RecorderBase, record_match code_pattern = r"```[\s\S]*?\n([\s\S]+)\n```" @@ -70,7 +66,7 @@ class FileSample(BaseModel): file_link: Optional[str] answerfile_name: Optional[str] answerfile_link: Optional[str] - compare_fields: List[str] + compare_fields: List[Union[str, Tuple]] def get_dataset(data_jsonl: str) -> list[FileSample]: @@ -98,11 +94,48 @@ def get_dataset(data_jsonl: str) -> list[FileSample]: if bucket.object_exists(oss_file): # 从 OSS 下载文件 bucket.get_object_to_file(oss_file, local_file) + raw_sample["compare_fields"] = [field if type(field) == str else tuple(field) for field in + raw_sample["compare_fields"]] print(raw_samples) samples = [FileSample(**raw_sample) for raw_sample in raw_samples] return samples +def fuzzy_compare(a: str, b: str) -> bool: + """ + Compare two strings with fuzzy matching. + """ + + def standardize_unit(s: str) -> str: + """ + Standardize a (affinity) string to common units. 
+ """ + mark = "" if re.search(r"[><=]", s) is None else re.search(r"[><=]", s).group() + unit = s.rstrip()[-2:] + number = re.search(r"[0-9.\+\-]+", s).group() + + if unit in ["µM", "uM"]: + unit = "nM" + number *= 1000 + elif unit in ["mM", "mm"]: + unit = "nM" + number *= 1000000 + + if mark == "=": + mark = "" + return f"{mark}{number:.1f} {unit}" + + unit_str = ["nM", "uM", "µM", "mM"] + a = a.strip() + b = b.strip() + if a[-2:] in unit_str and b[-2:] in unit_str: + a = standardize_unit(a) + b = standardize_unit(b) + return a == b + else: + return (a.lower() in b.lower()) or (b.lower() in a.lower()) + + class TableExtract(evals.Eval): def __init__( self, @@ -145,7 +178,7 @@ def eval_sample(self, sample, rng): table = pd.DataFrame() table = parse_table_multiindex(table) - correct_answer = pd.read_csv(sample.answerfile_name, header=[0, 1]) + correct_answer = pd.read_csv(sample.answerfile_name, header=[0, 1]).astype(str) for field in sample.compare_fields: match_field = field in table.columns and field in correct_answer.columns @@ -168,7 +201,7 @@ def eval_sample(self, sample, rng): for sample_value, correct_value in zip(table[field], correct_answer[field]): record_match( - correct=(sample_value == correct_value), + correct=fuzzy_compare(sample_value, correct_value), expected=correct_value, picked=sample_value, file_name=sample.file_name, From 19cd84a838a17d2584a94b56e8b2585ee0b9c6b6 Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Thu, 28 Dec 2023 10:01:37 +0800 Subject: [PATCH 05/22] add fuzzy_normalize for table headers --- evals/elsuite/table_extract.py | 73 +++++++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/evals/elsuite/table_extract.py b/evals/elsuite/table_extract.py index 0f788bf684..9a27aa58c4 100644 --- a/evals/elsuite/table_extract.py +++ b/evals/elsuite/table_extract.py @@ -16,9 +16,9 @@ from evals.api import CompletionFn from evals.record import RecorderBase, record_match -code_pattern = r"```[\s\S]*?\n([\s\S]+)\n```" -json_pattern = r"```json[\s\S]*?\n([\s\S]+)\n```" -csv_pattern = r"```csv[\s\S]*?\n([\s\S]+)\n```" +code_pattern = r"```[\s\S]*?\n([\s\S]+?)\n```" +json_pattern = r"```json[\s\S]*?\n([\s\S]+?)\n```" +csv_pattern = r"```csv[\s\S]*?\n([\s\S]+?)\n```" def parse_table_multiindex(table: pd.DataFrame) -> pd.DataFrame: @@ -27,20 +27,26 @@ def parse_table_multiindex(table: pd.DataFrame) -> pd.DataFrame: """ df = table.copy() - coltypes = {col: type(df[col].iloc[0]) for col in df.columns} - for col, ctype in coltypes.items(): - if ctype == str: - if ":" in df[col].iloc[0] and "," in df[col].iloc[0]: - df[col] = [{key: value for key, value in [pair.split(": ") for pair in data.split(", ")]} for data in - df[col]] - coltypes[col] = dict - dfs = [] - - for col, ctype in coltypes.items(): - if ctype == dict: - d = pd.DataFrame(df.pop(col).tolist()) - dfs.append(d) - df = pd.concat([df] + dfs, axis=1) + if df.columns.nlevels == 1: + coltypes = {col: type(df[col].iloc[0]) for col in df.columns} + for col, ctype in coltypes.items(): + if ctype == str: + if ":" in df[col].iloc[0] and "," in df[col].iloc[0]: + df[col] = [{key: value for key, value in [pair.split(": ") for pair in data.split(", ")]} for data + in df[col]] + coltypes[col] = dict + dfs = [] + + for col, ctype in coltypes.items(): + if ctype == dict: + d = pd.DataFrame(df.pop(col).tolist()) + d.columns = pd.MultiIndex.from_tuples([(col, fuzzy_normalize(key)) for key in d.columns]) + dfs.append(d) + df.columns = pd.MultiIndex.from_tuples([(col, 
"") for col in df.columns]) + df = pd.concat([df] + dfs, axis=1) + if df.columns.nlevels > 1: + df.columns = pd.MultiIndex.from_tuples([(col, fuzzy_normalize(subcol)) for col, subcol in df.columns]) + return df @@ -136,6 +142,31 @@ def standardize_unit(s: str) -> str: return (a.lower() in b.lower()) or (b.lower() in a.lower()) +def fuzzy_normalize(s): + """ 标准化字符串 """ + # 定义需要移除的单位和符号 + units = ["µM", "µg/mL", "nM"] + for unit in units: + s = s.replace(unit, "") + + # 定义特定关键字 + keywords = ["IC50", "EC50", "TC50", "GI50", "Ki", "Kd"] + + # 移除非字母数字的字符,除了空格 + s = re.sub(r'[^\w\s]', '', s) + + # 分割字符串为单词列表 + words = s.split() + + # 将关键字移到末尾 + reordered_words = [word for word in words if word not in keywords] + keywords_in_string = [word for word in words if word in keywords] + reordered_words.extend(keywords_in_string) + + # 重新组合为字符串 + return ' '.join(reordered_words) + + class TableExtract(evals.Eval): def __init__( self, @@ -170,6 +201,9 @@ def eval_sample(self, sample, rng): code = re.search(code_pattern, sampled).group() code_content = re.sub(code_pattern, r"\1", code) table = pd.read_csv(StringIO(code_content)) + if pd.isna(table.iloc[0, 0]): + table = pd.read_csv(StringIO(code_content), header=[0, 1]) + elif "json" in prompt: code = re.search(code_pattern, sampled).group() code_content = re.sub(code_pattern, r"\1", code) @@ -177,10 +211,13 @@ def eval_sample(self, sample, rng): else: table = pd.DataFrame() table = parse_table_multiindex(table) + table.to_csv("temp1.csv") - correct_answer = pd.read_csv(sample.answerfile_name, header=[0, 1]).astype(str) + correct_answer = parse_table_multiindex(pd.read_csv(sample.answerfile_name, header=[0, 1]).astype(str)) for field in sample.compare_fields: + if type(field) == tuple: + field = (field[0], fuzzy_normalize(field[1])) match_field = field in table.columns and field in correct_answer.columns record_match( correct=match_field, From 035de6e68f65ca236948763f200bbd25cd5fbeb8 Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Fri, 29 Dec 2023 12:41:06 +0800 Subject: [PATCH 06/22] add uni-finder completion_fn and separated format tests (json/csv) --- evals/completion_fns/uni_finder.py | 85 +++++++++++++++++++ evals/elsuite/table_extract.py | 19 +++-- evals/registry/completion_fns/uni_finder.yaml | 9 ++ .../data/00_scipaper_affinity/samples.jsonl | 3 + .../registry/evals/00_scipaper_affinity.yaml | 24 ++++-- 5 files changed, 129 insertions(+), 11 deletions(-) create mode 100644 evals/completion_fns/uni_finder.py create mode 100644 evals/registry/completion_fns/uni_finder.yaml create mode 100644 evals/registry/data/00_scipaper_affinity/samples.jsonl diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py new file mode 100644 index 0000000000..32e61f324f --- /dev/null +++ b/evals/completion_fns/uni_finder.py @@ -0,0 +1,85 @@ +""" +Extending Completion Functions with Embeddings-based retrieval from a fetched dataset +""" +import os +import time +import requests +from typing import Any, Optional, Union + +import numpy as np +from openai import OpenAI + +import pandas as pd + +from evals.api import CompletionFn, CompletionResult +from evals.record import record_sampling +from evals.utils.api_utils import ( + request_with_timeout +) + + +class UniFinderCompletionResult(CompletionResult): + def __init__(self, response: str) -> None: + self.response = response + + def get_completions(self) -> list[str]: + return [self.response.strip()] + + +class UniFinderCompletionFn(CompletionFn): + """ + This Completion 
Function uses embeddings to retrieve the top k relevant docs from a dataset to the prompt, then adds them to the context before calling the completion. + """ + + def __init__( + self, + model: Optional[str] = None, + instructions: Optional[str] = "You are a helpful assistant on extracting information from files.", + api_base: Optional[str] = None, + api_key: Optional[str] = None, + n_ctx: Optional[int] = None, + extra_options: Optional[dict] = {}, + **kwargs + ): + self.model = model + self.instructions = instructions + self.api_base = api_base or os.environ.get("UNIFINDER_API_BASE") + self.api_key = api_key or os.environ.get("UNIFINDER_API_KEY") + self.n_ctx = n_ctx + self.extra_options = extra_options + + def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCompletionResult: + """ + Args: + prompt: The prompt to complete, in either text string or Chat format. + kwargs: Additional arguments to pass to the completion function call method. + """ + + pdf_token = [] + if "file_name" in kwargs: + url = f"{self.api_base}/api/external/upload_pdf" + pdf_parse_mode = 'fast' # or 'precise', 指定使用的pdf解析版本 + files = {'file': open(kwargs["file_name"], 'rb')} + data = { + 'pdf_parse_mode': pdf_parse_mode, + 'api_key': self.api_key + } + response = requests.post(url, data=data, files=files).json() + pdf_id = response['pdf_token'] # 获得pdf的id,表示上传成功,后续可以使用这个id来指定pdf + pdf_token.append(pdf_id) + + assert "file_name" in kwargs, "Must provide a file_name to retrieve." + + url = f"{self.api_base}/api/external/chatpdf" + + payload = { + "model_engine": self.model, + "pdf_token": pdf_token, + "query": prompt, + 'api_key': self.api_key + } + response = requests.post(url, json=payload).json() + answer = response['answer'] + print(answer) + record_sampling(prompt=prompt, sampled=answer) + return UniFinderCompletionResult(answer) diff --git a/evals/elsuite/table_extract.py b/evals/elsuite/table_extract.py index 9a27aa58c4..cc1447a130 100644 --- a/evals/elsuite/table_extract.py +++ b/evals/elsuite/table_extract.py @@ -1,6 +1,7 @@ from io import StringIO import json import os +from pathlib import Path import re from typing import List, Optional, Tuple, Union @@ -99,6 +100,7 @@ def get_dataset(data_jsonl: str) -> list[FileSample]: if not os.path.exists(local_file): if bucket.object_exists(oss_file): # 从 OSS 下载文件 + Path(local_file).parent.mkdir(parents=True, exist_ok=True) bucket.get_object_to_file(oss_file, local_file) raw_sample["compare_fields"] = [field if type(field) == str else tuple(field) for field in raw_sample["compare_fields"]] @@ -131,13 +133,16 @@ def standardize_unit(s: str) -> str: mark = "" return f"{mark}{number:.1f} {unit}" - unit_str = ["nM", "uM", "µM", "mM"] + unit_str = ["nM", "uM", "µM", "mM", "%", " %"] + nan_str = ["n/a", "nan", "na", "nd", "not determined", "not tested"] a = a.strip() b = b.strip() - if a[-2:] in unit_str and b[-2:] in unit_str: + if (a[-2:] in unit_str or a[-1] in unit_str) and (b[-2:] in unit_str or b[-1] in unit_str): a = standardize_unit(a) b = standardize_unit(b) return a == b + elif a.lower() in nan_str and b.lower() in nan_str: + return True else: return (a.lower() in b.lower()) or (b.lower() in a.lower()) @@ -206,14 +211,16 @@ def eval_sample(self, sample, rng): elif "json" in prompt: code = re.search(code_pattern, sampled).group() - code_content = re.sub(code_pattern, r"\1", code) + code_content = re.sub(code_pattern, r"\1", code).replace("\"", "") table = pd.DataFrame(json.loads(code_content)) else: table = pd.DataFrame() - table = 
parse_table_multiindex(table) - table.to_csv("temp1.csv") + table = parse_table_multiindex(table).sort_values(by="Compound") + + correct_answer = parse_table_multiindex( + pd.read_csv(sample.answerfile_name, header=[0, 1]).astype(str)).sort_values(by="Compound") - correct_answer = parse_table_multiindex(pd.read_csv(sample.answerfile_name, header=[0, 1]).astype(str)) + table.to_csv(sample.answerfile_name.replace(".csv", "_output.csv")) for field in sample.compare_fields: if type(field) == tuple: diff --git a/evals/registry/completion_fns/uni_finder.yaml b/evals/registry/completion_fns/uni_finder.yaml new file mode 100644 index 0000000000..b31c691445 --- /dev/null +++ b/evals/registry/completion_fns/uni_finder.yaml @@ -0,0 +1,9 @@ +uni_finder/gpt-3.5-turbo: + class: evals.completion_fns.uni_finder:UniFinderCompletionFn + args: + model: gpt35 + +uni_finder/gpt-4-all: + class: evals.completion_fns.uni_finder:UniFinderCompletionFn + args: + model: gpt4 diff --git a/evals/registry/data/00_scipaper_affinity/samples.jsonl b/evals/registry/data/00_scipaper_affinity/samples.jsonl new file mode 100644 index 0000000000..4d1ea75b61 --- /dev/null +++ b/evals/registry/data/00_scipaper_affinity/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8151e952830ec4993a74d682cadf973f294d9faaa4cdae0908ca7c4e62dcc73 +size 1304 diff --git a/evals/registry/evals/00_scipaper_affinity.yaml b/evals/registry/evals/00_scipaper_affinity.yaml index 7a48117aa0..de865a86a9 100644 --- a/evals/registry/evals/00_scipaper_affinity.yaml +++ b/evals/registry/evals/00_scipaper_affinity.yaml @@ -1,18 +1,17 @@ scipaper_affinity: id: scipaper_affinity.val.ab-v1 metrics: [accuracy] -scipaper_affinity.val.ab-v1: +scipaper_affinity.val.json: class: evals.elsuite.table_extract:TableExtract args: - dataset: /Users/chang/AI_projects/Uni-finder/assays/0003001_0004000/answer/data.jsonl + dataset: 00_scipaper_affinity/samples.jsonl instructions: | - Please summarize the names, SMILES, affinities, target info (protein or cell line), and affinity types (chosen from IC50, EC50, TC50, GI50, Ki, Kd) of all inhibitors in the paper, and output in json format. For example: + Please give a complete list of names, affinities, target info (protein or cell line), and affinity types (chosen from IC50, EC50, TC50, GI50, Ki, Kd) of all inhibitors in the paper. If there are multiple tables, combine them. Don't give me reference. Output in json format. For example: ```json [ { "Compound": "5a", - "Name": "d4a", - "SMILES": "Unknown", + "Name": "Aspirin", "Affinities": { "5HT1A (IC50)": "2.0 nM", "5HT1D (IC50)": "8.0 nM", @@ -22,3 +21,18 @@ scipaper_affinity.val.ab-v1: } ] ``` + +scipaper_affinity.val.csv: + class: evals.elsuite.table_extract:TableExtract + args: + dataset: 00_scipaper_affinity/samples.jsonl + instructions: | + Please give a complete list of affinities, target info (protein or cell line), and affinity types (chosen from IC50, EC50, TC50, GI50, Ki, Kd) of all inhibitors in the paper. + 1. Find all the tables with relevant information + 2. Output in csv format with multiindex (Affinities, protein/cell line), write units not in header but in the value like "10.5 µM". Quote the value if it has comma! For example: + ```csv + Compound,Name,Affinities,Affinities,Affinities,Affinities + ,,5HT1A (IC50),5HT1D (IC50),5HT-UT (IC50),5HT1E () + "5a","1,2-dimethyl Aspirin",2.0 nM,8.0 nM,12.6 nM,>1000 nM + ``` + 3. If there are multiple tables, concat them. Don't give me reference or using "...", give me complete table! 
\ No newline at end of file From 7d91b0a96d2c0740ac89cb3afc99e129c2909e87 Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Fri, 29 Dec 2023 16:00:24 +0800 Subject: [PATCH 07/22] basic mlops loggers --- evals/cli/oaieval.py | 28 +++ evals/elsuite/table_extract.py | 99 +++++---- evals/reporters/DPTracking.py | 104 ++++++++++ evals/reporters/Feishu.py | 369 +++++++++++++++++++++++++++++++++ evals/reporters/WandB.py | 43 ++++ evals/reporters/__init__.py | 0 6 files changed, 606 insertions(+), 37 deletions(-) create mode 100644 evals/reporters/DPTracking.py create mode 100644 evals/reporters/Feishu.py create mode 100644 evals/reporters/WandB.py create mode 100644 evals/reporters/__init__.py diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py index 20b7d4c3bf..a17898b164 100644 --- a/evals/cli/oaieval.py +++ b/evals/cli/oaieval.py @@ -2,9 +2,12 @@ This file defines the `oaieval` CLI for running evals. """ import argparse +import json import logging +import re import shlex import sys +from pathlib import Path from typing import Any, Mapping, Optional, Union, cast import openai @@ -48,6 +51,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument( "--log_to_file", type=str, default=None, help="Log to a file instead of stdout" ) + parser.add_argument("--mlops", type=str, default=None) parser.add_argument( "--registry_path", type=str, @@ -106,6 +110,7 @@ class OaiEvalArguments(argparse.Namespace): user: str record_path: Optional[str] log_to_file: Optional[str] + mlops: Optional[str] registry_path: list[str] debug: bool local_run: bool @@ -229,6 +234,29 @@ def to_number(x: str) -> Union[int, float, str]: logger.info("Final report:") for key, value in result.items(): logger.info(f"{key}: {value}") + + if args.mlops: + import pandas as pd + with open(record_path, "r") as f: + events_df = pd.read_json(f, lines=True) + + run_config = events_df.loc[0, "spec"] + matches_df = events_df[events_df.type == "match"].reset_index(drop=True) + matches_df = matches_df.join(pd.json_normalize(matches_df.data)) + + matches_df["doi"] = [re.sub("__([0-9]+)__", "(\1)", Path(f).stem).replace("_", "/") for f in matches_df["file_name"]] + + # TODO: compare on different completion_functions + accuracy_by_type_and_file = matches_df.groupby(["jobtype", "doi"])['correct'].mean() + accuracy_by_type = matches_df.groupby(["jobtype"])['correct'].mean() + + logger_data = {} + + config_logger = json.load(open(args.mlops, 'r')) + if "dp_mlops" in config_logger: + from evals.reporters.DPTracking import DPTrackingReporter + DPTrackingReporter.report_run(config_logger, run_config, logger_data, step=0) + return run_spec.run_id diff --git a/evals/elsuite/table_extract.py b/evals/elsuite/table_extract.py index cc1447a130..43aef80c59 100644 --- a/evals/elsuite/table_extract.py +++ b/evals/elsuite/table_extract.py @@ -120,7 +120,7 @@ def standardize_unit(s: str) -> str: """ mark = "" if re.search(r"[><=]", s) is None else re.search(r"[><=]", s).group() unit = s.rstrip()[-2:] - number = re.search(r"[0-9.\+\-]+", s).group() + number = float(re.search(r"[0-9.\+\-]+", s).group()) if unit in ["µM", "uM"]: unit = "nM" @@ -148,28 +148,30 @@ def standardize_unit(s: str) -> str: def fuzzy_normalize(s): - """ 标准化字符串 """ - # 定义需要移除的单位和符号 - units = ["µM", "µg/mL", "nM"] - for unit in units: - s = s.replace(unit, "") - - # 定义特定关键字 - keywords = ["IC50", "EC50", "TC50", "GI50", "Ki", "Kd"] + if s.startswith("Unnamed"): + return "" + else: + """ 标准化字符串 """ + # 定义需要移除的单位和符号 + units = ["µM", "µg/mL", "nM"] + for unit 
in units: + s = s.replace(unit, "") - # 移除非字母数字的字符,除了空格 - s = re.sub(r'[^\w\s]', '', s) + # 定义特定关键字 + keywords = ["IC50", "EC50", "TC50", "GI50", "Ki", "Kd"] - # 分割字符串为单词列表 - words = s.split() + # 移除非字母数字的字符,除了空格 + s = re.sub(r'[^\w\s]', '', s) - # 将关键字移到末尾 - reordered_words = [word for word in words if word not in keywords] - keywords_in_string = [word for word in words if word in keywords] - reordered_words.extend(keywords_in_string) + # 分割字符串为单词列表 + words = s.split() - # 重新组合为字符串 - return ' '.join(reordered_words) + # 将关键字移到末尾 + reordered_words = [word for word in words if word not in keywords] + keywords_in_string = [word for word in words if word in keywords] + reordered_words.extend(keywords_in_string) + # 重新组合为字符串 + return ' '.join(reordered_words) class TableExtract(evals.Eval): @@ -202,30 +204,43 @@ def eval_sample(self, sample, rng): ) sampled = result.get_completions()[0] - if "csv" in prompt: - code = re.search(code_pattern, sampled).group() - code_content = re.sub(code_pattern, r"\1", code) - table = pd.read_csv(StringIO(code_content)) - if pd.isna(table.iloc[0, 0]): - table = pd.read_csv(StringIO(code_content), header=[0, 1]) - - elif "json" in prompt: - code = re.search(code_pattern, sampled).group() - code_content = re.sub(code_pattern, r"\1", code).replace("\"", "") - table = pd.DataFrame(json.loads(code_content)) - else: - table = pd.DataFrame() - table = parse_table_multiindex(table).sort_values(by="Compound") - - correct_answer = parse_table_multiindex( - pd.read_csv(sample.answerfile_name, header=[0, 1]).astype(str)).sort_values(by="Compound") + correct_answer = parse_table_multiindex(pd.read_csv(sample.answerfile_name, header=[0, 1]).astype(str)) + correct_answer = correct_answer.sort_values(by=("Compound", "")) + correct_str = correct_answer.to_csv() + + try: + if "csv" in prompt: + code = re.search(code_pattern, sampled).group() + code_content = re.sub(code_pattern, r"\1", code) + table = pd.read_csv(StringIO(code_content)) + if pd.isna(table.iloc[0, 0]): + table = pd.read_csv(StringIO(code_content), header=[0, 1]) + + elif "json" in prompt: + code = re.search(code_pattern, sampled).group() + code_content = re.sub(code_pattern, r"\1", code).replace("\"", "") + table = pd.DataFrame(json.loads(code_content)) + else: + table = pd.DataFrame() + table = parse_table_multiindex(table).sort_values(by="Compound") + except: + record_match( + correct=False, + expected=correct_str, + picked=sampled, + file_name=sample.file_name, + jobtype="match_all" + ) + return table.to_csv(sample.answerfile_name.replace(".csv", "_output.csv")) + match_all = True for field in sample.compare_fields: if type(field) == tuple: field = (field[0], fuzzy_normalize(field[1])) match_field = field in table.columns and field in correct_answer.columns + match_all = match_all and match_field record_match( correct=match_field, expected=field, @@ -235,6 +250,7 @@ def eval_sample(self, sample, rng): ) if match_field: match_number = table[field].shape[0] == correct_answer[field].shape[0] + match_all = match_all and match_number record_match( correct=match_number, expected=correct_answer[field].shape[0], @@ -244,13 +260,22 @@ def eval_sample(self, sample, rng): ) for sample_value, correct_value in zip(table[field], correct_answer[field]): + match_value = fuzzy_compare(str(sample_value), str(correct_value)) + match_all = match_all and match_value record_match( - correct=fuzzy_compare(sample_value, correct_value), + correct=match_value, expected=correct_value, picked=sample_value, file_name=sample.file_name, 
jobtype="match_value" ) + record_match( + correct=match_all, + expected=correct_str, + picked=table.to_string(), + file_name=sample.file_name, + jobtype="match_all" + ) def run(self, recorder: RecorderBase): samples = get_dataset(self.dataset) diff --git a/evals/reporters/DPTracking.py b/evals/reporters/DPTracking.py new file mode 100644 index 0000000000..04aa695ef9 --- /dev/null +++ b/evals/reporters/DPTracking.py @@ -0,0 +1,104 @@ +import glob +import os +import time +import uuid +from copy import deepcopy +from datetime import datetime +from pathlib import Path +from typing import Dict, Union, List, Any + +import numpy as np +import pandas as pd +import aim +from PIL import Image +from rdkit import Chem + + +class DPTrackingReporter: + @staticmethod + def _convert_logger_table(df: pd.DataFrame) -> aim.Table: + aim_df = deepcopy(df) + if aim_df.shape[0] == 0: + return aim.Table(aim_df) + for col in aim_df.columns: + i = 0 + while not aim_df.loc[i, col]: + i += 1 + if i == aim_df.shape[0]: + i = 0 + break + data0 = aim_df.loc[i, col] + if isinstance(data0, Chem.Mol): + molfiles = [] + tmpdir = f"aim-tmp-{uuid.uuid4().hex}" + Path(tmpdir).mkdir(exist_ok=True, parents=True) + for i, mol in enumerate(aim_df[col]): + if mol: + molfile = f"{tmpdir}/{i}.sdf" + Chem.MolToMolFile(mol, molfile) + molfiles.append(molfile) + else: + molfiles.append(None) + aim_df[col] = [aim.Molecule(molfile) if molfile else None for molfile in molfiles] + elif isinstance(data0, Image.Image): + imgfiles = [] + tmpdir = f"aim-tmp-{uuid.uuid4().hex}" + Path(tmpdir).mkdir(exist_ok=True, parents=True) + for i, img in enumerate(aim_df[col]): + if img: + imgfile = f"{tmpdir}/{i}.png" + img.save(imgfile) + imgfiles.append(imgfile) + else: + imgfiles.append(None) + aim_df[col] = [aim.TableImage(imgfile) if imgfile else None for imgfile in imgfiles] + return aim.Table(aim_df) + + @staticmethod + def _convert_logger_data(v: Any) -> Any: + import matplotlib.pyplot as plt + try: + import plotly.graph_objects as go + except ImportError: + go = plt + if type(v) in [go.Figure, plt.Figure]: + return aim.Figure(v) + if type(v) in [Image.Image] or (type(v) == str and Path(v).exists() and Path(v).suffix in [".png", ".jpg"]): + return aim.Image(v) + if type(v) in [pd.DataFrame]: + return DPTrackingReporter._convert_logger_table(v) + if type(v) in [np.ndarray, list]: + return aim.Distribution(v) + return v + + @staticmethod + def report_run(config_logger: Dict, config_run: Dict = {}, logger_data: Dict = {}, step: int = -1): + dp_mlops_config = config_logger["dp_mlops"] + + # Experiment Tracking + os.environ["AIM_ACCESS_TOKEN"] = dp_mlops_config["aim_personal_token"] + print('debug report_sampler: run_hash', config_logger["hash"], datetime.now()) + # os.environ['AIM_UNSAFE_SESSION_COOKIE'] = config_logger["hash"] + run = aim.Run( + experiment=config_logger["project"], + run_hash=config_logger.get("hash", None), + repo=dp_mlops_config["aim_repo"] + ) + # run = Run(experiment=config_logger["project"], repo=dp_mlops_config["aim_repo"]) + run.name = config_logger["name"] + run.hparams["config"] = config_run + for tag in set([config_logger["name"]] + dp_mlops_config.get("tags", [])): + if tag and tag.lower() not in [t.lower() for t in run.props.tags]: + print(tag.lower(), run.props.tags) + run.add_tag(tag.lower()) + + DPTrackingReporter._convert_logger_data(logger_data) + + for key, value in logger_data.items(): + if "/" not in key or "kcal/mol" in key: + run.track(value, name=key, context={}) + else: + key, context_str = key.split("/") + 
context_dict = {k: v for k, v in [kv.split(":") for kv in context_str.split(",")]} + run.track(value, name=key, context={**context_dict}) + run.close() diff --git a/evals/reporters/Feishu.py b/evals/reporters/Feishu.py new file mode 100644 index 0000000000..65ccc610aa --- /dev/null +++ b/evals/reporters/Feishu.py @@ -0,0 +1,369 @@ +import os +from pathlib import Path +import json +import datetime + +from typing import Dict, Union, List + +import requests + +# 时间、实验名、项目、成功体系占比、Protocol、imgkey、Tracking链接、工作流链接 +FEISHU_MESSAGE_STRING = \ + ''' +{ + "config": { + "wide_screen_mode": true + }, + "elements": [ + { + "fields": [ + { + "is_short": true, + "text": { + "content": "**🕐 时间:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": true, + "text": { + "content": "**🔢 实验名:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": false, + "text": { + "content": "", + "tag": "lark_md" + } + }, + { + "is_short": true, + "text": { + "content": "**📋 项目:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": true, + "text": { + "content": "**📋 成功体系:**\n%s", + "tag": "lark_md" + } + } + ], + "tag": "div" + }, + { + "fields": [ + { + "is_short": false, + "text": { + "content": "**🕐 Protocol:**\n%s", + "tag": "lark_md" + } + } + ], + "tag": "div" + }, + { + "alt": { + "content": "", + "tag": "plain_text" + }, + "img_key": "%s", + "tag": "img", + "title": { + "content": "Metrics 汇总:", + "tag": "lark_md" + } + }, + { + "actions": [ + { + "tag": "button", + "text": { + "content": "跟进处理", + "tag": "plain_text" + }, + "type": "primary", + "value": { + "key1": "value1" + } + }, + { + "options": [ + { + "text": { + "content": "屏蔽10分钟", + "tag": "plain_text" + }, + "value": "1" + }, + { + "text": { + "content": "屏蔽30分钟", + "tag": "plain_text" + }, + "value": "2" + }, + { + "text": { + "content": "屏蔽1小时", + "tag": "plain_text" + }, + "value": "3" + }, + { + "text": { + "content": "屏蔽24小时", + "tag": "plain_text" + }, + "value": "4" + } + ], + "placeholder": { + "content": "暂时屏蔽实验跟踪", + "tag": "plain_text" + }, + "tag": "select_static", + "value": { + "key": "value" + } + } + ], + "tag": "action" + }, + { + "tag": "hr" + }, + { + "tag": "div", + "text": { + "content": "📝 [Tracking链接](%s) | 🙋 [工作流链接](%s)", + "tag": "lark_md" + } + } + ], + "header": { + "template": "green", + "title": { + "content": "IFD 实验跟踪", + "tag": "plain_text" + } + } +} +''' + +FEISHU_MESSAGE = { + "config": { + "wide_screen_mode": True + }, + "elements": [ + { + "fields": [ + { + "is_short": True, + "text": { + "content": "**🕐 时间:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": True, + "text": { + "content": "**🔢 实验名:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": False, + "text": { + "content": "", + "tag": "lark_md" + } + }, + { + "is_short": True, + "text": { + "content": "**📋 项目:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": True, + "text": { + "content": "**📋 成功体系:**\n%s", + "tag": "lark_md" + } + } + ], + "tag": "div" + }, + { + "fields": [ + { + "is_short": False, + "text": { + "content": "**🕐 Protocol:**\n%s", + "tag": "lark_md" + } + } + ], + "tag": "div" + }, + { + "alt": { + "content": "", + "tag": "plain_text" + }, + "img_key": "%s", + "tag": "img", + "title": { + "content": "Metrics 汇总:", + "tag": "lark_md" + } + }, + { + "actions": [ + { + "tag": "button", + "text": { + "content": "跟进处理", + "tag": "plain_text" + }, + "type": "primary", + "value": { + "key1": "value1" + } + }, + { + "options": [ + { + "text": { + "content": "屏蔽10分钟", + "tag": "plain_text" + }, + "value": "1" + }, + { + "text": { + 
"content": "屏蔽30分钟", + "tag": "plain_text" + }, + "value": "2" + }, + { + "text": { + "content": "屏蔽1小时", + "tag": "plain_text" + }, + "value": "3" + }, + { + "text": { + "content": "屏蔽24小时", + "tag": "plain_text" + }, + "value": "4" + } + ], + "placeholder": { + "content": "暂时屏蔽实验跟踪", + "tag": "plain_text" + }, + "tag": "select_static", + "value": { + "key": "value" + } + } + ], + "tag": "action" + }, + { + "tag": "hr" + }, + { + "tag": "div", + "text": { + "content": "📝 [Tracking链接](%s) | 🙋 [工作流链接](%s)", + "tag": "lark_md" + } + } + ], + "header": { + "template": "green", + "title": { + "content": "IFD 实验跟踪", + "tag": "plain_text" + } + } +} + + +class FeishuReporter: + @staticmethod + def _get_tenant_token(app_id: str = "cli_a301e6759d32500c", app_secret: str = "uLiHOmf0QOQRkhwymy8AmfHWykMQaMFk"): + url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" + + payload = json.dumps({ + "app_id": app_id, + "app_secret": app_secret + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + response.raise_for_status() + data = response.json() + assert data['code'] == 0 + return data['tenant_access_token'] + + @staticmethod + def _upload_image(file_path, type='image/png', app_id: str = "cli_a301e6759d32500c", + app_secret: str = "uLiHOmf0QOQRkhwymy8AmfHWykMQaMFk"): + url = "https://open.feishu.cn/open-apis/im/v1/images" + payload = {'image_type': 'message'} + files = [ + ('image', (Path(file_path).stem, open(file_path, 'rb'), type)) + ] + token = FeishuReporter._get_tenant_token(app_id=app_id, app_secret=app_secret) + headers = { + 'Authorization': f'Bearer {token}' + } + response = requests.request("POST", url, headers=headers, data=payload, files=files) + response.raise_for_status() + data = response.json() + assert data['code'] == 0 + return data['data']['image_key'] + + @staticmethod + def report_run(feishu_groups: List, experiment_group: str, project: str, success_ratio: str, + config_protocol: Dict, + imgfile: Union[str, Path], track_url: str, workflow_url: str, + app_id: str = "", app_secret: str = ""): + app_id = os.environ.get("FEISHU_APP_ID", app_id) + app_secret = os.environ.get("FEISHU_APP_SECRET", app_secret) + now = datetime.datetime.now() + img_key = FeishuReporter._upload_image(imgfile, app_id=app_id, app_secret=app_secret) + + message = FEISHU_MESSAGE.copy() + + message["elements"][0]["fields"][0]["text"]["content"] = \ + message["elements"][0]["fields"][0]["text"]["content"] % now.strftime("%Y-%m-%d %H:%M:%S") + message["elements"][0]["fields"][1]["text"]["content"] = \ + message["elements"][0]["fields"][1]["text"]["content"] % experiment_group + message["elements"][0]["fields"][3]["text"]["content"] = \ + message["elements"][0]["fields"][3]["text"]["content"] % project + message["elements"][0]["fields"][4]["text"]["content"] = \ + message["elements"][0]["fields"][4]["text"]["content"] % success_ratio + message["elements"][1]["fields"][0]["text"]["content"] = \ + message["elements"][1]["fields"][0]["text"]["content"] % json.dumps(config_protocol, indent=4) + message["elements"][2]["img_key"] = img_key + message["elements"][5]["text"]["content"] = message["elements"][5]["text"]["content"] % ( + track_url, workflow_url) + + for feishu_group in feishu_groups: + requests.post(feishu_group, + json={"msg_type": "interactive", "card": message}) diff --git a/evals/reporters/WandB.py b/evals/reporters/WandB.py new file mode 100644 index 0000000000..7fa45d4c25 --- /dev/null +++ 
b/evals/reporters/WandB.py
@@ -0,0 +1,43 @@
+from pathlib import Path
+from typing import Dict, Union, List
+import traceback
+
+import pandas as pd
+
+try:
+    import wandb
+except:
+    print("No wandb found!")
+
+
+class WandBReporter:
+    @staticmethod
+    def report_run(config_logger: Dict, metric_data: pd.DataFrame, step: int = -1):
+        logger_data = {}
+
+        # Plotly figures can be attached here as well, e.g. logger_data["fig"] = wandb.Plotly(fig)
+
+        wandb_config = config_logger.get("wandb", {}).copy()
+        wandb_config["name"] = config_logger["name"]
+        wandb_config["group"] = config_logger["group"]
+        wandb_config["id"] = config_logger["id"]
+        wandb.login(key=wandb_config.pop('key'))
+
+        try:
+            run = wandb.init(**wandb_config)
+        except:
+            traceback.print_exc()
+            wandb_config["mode"] = "offline"
+            run = wandb.init(**wandb_config)
+        sampler_metric_wb = wandb.Table(dataframe=metric_data)
+        logger_data["sampler_metrics"] = sampler_metric_wb
+
+        if step >= 0:
+            wandb.log(data=logger_data, step=step)
+        else:
+            wandb.log(data=logger_data)
+        wandb.finish()
+
+    @staticmethod
+    def report_summary():
+        pass
diff --git a/evals/reporters/__init__.py b/evals/reporters/__init__.py
new file mode 100644
index 0000000000..e69de29bb2

From d4ad1fdd8ca08f908849dfb8a252dc155bd19900 Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Fri, 29 Dec 2023 18:50:01 +0800
Subject: [PATCH 08/22] bugfixes on example showcase

---
 evals/cli/oaieval.py                          | 29 +++++++++--
 evals/elsuite/table_extract.py                | 13 +++--
 .../data/00_scipaper_affinity/samples.jsonl   |  4 +-
 evals/reporters/DPTracking.py                 | 50 +++++++++----------
 examples/config_logger.json                   |  7 +++
 5 files changed, 67 insertions(+), 36 deletions(-)
 create mode 100644 examples/config_logger.json

diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py
index a17898b164..0841672fb4 100644
--- a/evals/cli/oaieval.py
+++ b/evals/cli/oaieval.py
@@ -4,9 +4,11 @@
 import argparse
 import json
 import logging
+import pickle
 import re
 import shlex
 import sys
+from io import StringIO
 from pathlib import Path
 from typing import Any, Mapping, Optional, Union, cast

@@ -237,22 +239,41 @@ def to_number(x: str) -> Union[int, float, str]:

     if args.mlops:
         import pandas as pd
+        import plotly.express as px
+
+        recorder.flush_events()
         with open(record_path, "r") as f:
             events_df = pd.read_json(f, lines=True)
+        print(events_df)

         run_config = events_df.loc[0, "spec"]
-        matches_df = events_df[events_df.type == "match"].reset_index(drop=True)
+        matches_df = events_df[events_df["type"] == "match"].reset_index(drop=True)
         matches_df = matches_df.join(pd.json_normalize(matches_df.data))

-        matches_df["doi"] = [re.sub("__([0-9]+)__", "(\1)", Path(f).stem).replace("_", "/") for f in matches_df["file_name"]]
+        matches_df["doi"] = [re.sub("__([0-9]+)__", r"(\1)", Path(f).stem).replace("_", "/") for f in matches_df["file_name"]]

         # TODO: compare on different completion_functions
-        accuracy_by_type_and_file = matches_df.groupby(["jobtype", "doi"])['correct'].mean()
-        accuracy_by_type = matches_df.groupby(["jobtype"])['correct'].mean()
+        accuracy_by_type_and_file = matches_df.groupby(["jobtype", "doi"])['correct'].mean().reset_index()
+        accuracy_by_type = matches_df.groupby(["jobtype"])['correct'].mean().to_dict()
+
+        print(accuracy_by_type_and_file)
+
+        logger_data = {
+            **accuracy_by_type,
+            "Accuracy": px.box(accuracy_by_type_and_file, x="jobtype", y="correct", color="jobtype", title="Accuracy by jobtype and model"),
+        }

-        logger_data = {}
+        for doi, df in matches_df.groupby("doi"):
+            logger_data[f"{doi.replace('/', '_')}/context:match"] = df[df["jobtype"] != "match_all"][["correct", "expected", "picked", "jobtype"]]
+            match_all_data = 
df[df["jobtype"] == "match_all"].iloc[0, :] + logger_data[f"{doi.replace('/', '_')}/context:truth"] = pd.read_csv(StringIO(match_all_data["expected"]), header=[0, 1]) + logger_data[f"{doi.replace('/', '_')}/context:extract"] = pd.read_csv(StringIO(match_all_data["picked"]), header=[0, 1]) \ + if df["jobtype"].iloc[0] != "match_all" else match_all_data["picked"] + pickle.dump(logger_data, open("logger_data.pkl", "wb")) config_logger = json.load(open(args.mlops, 'r')) + if "name" not in config_logger.keys(): + config_logger["name"] = f"{run_spec.run_id}_{args.completion_fn}_{args.eval}" if "dp_mlops" in config_logger: from evals.reporters.DPTracking import DPTrackingReporter DPTrackingReporter.report_run(config_logger, run_config, logger_data, step=0) diff --git a/evals/elsuite/table_extract.py b/evals/elsuite/table_extract.py index 43aef80c59..e3aaa98d5d 100644 --- a/evals/elsuite/table_extract.py +++ b/evals/elsuite/table_extract.py @@ -206,7 +206,8 @@ def eval_sample(self, sample, rng): correct_answer = parse_table_multiindex(pd.read_csv(sample.answerfile_name, header=[0, 1]).astype(str)) correct_answer = correct_answer.sort_values(by=("Compound", "")) - correct_str = correct_answer.to_csv() + correct_answer.to_csv("temp.csv", index=False) + correct_str = open("temp.csv", 'r').read() try: if "csv" in prompt: @@ -222,7 +223,7 @@ def eval_sample(self, sample, rng): table = pd.DataFrame(json.loads(code_content)) else: table = pd.DataFrame() - table = parse_table_multiindex(table).sort_values(by="Compound") + table = parse_table_multiindex(table).sort_values(by=("Compound", "")) except: record_match( correct=False, @@ -233,7 +234,9 @@ def eval_sample(self, sample, rng): ) return - table.to_csv(sample.answerfile_name.replace(".csv", "_output.csv")) + answerfile_out = sample.answerfile_name.replace(".csv", "_output.csv") + table.to_csv(answerfile_out, index=False) + picked_str = open(answerfile_out, 'r').read() match_all = True for field in sample.compare_fields: @@ -267,12 +270,12 @@ def eval_sample(self, sample, rng): expected=correct_value, picked=sample_value, file_name=sample.file_name, - jobtype="match_value" + jobtype=field if type(field) == str else field[0] ) record_match( correct=match_all, expected=correct_str, - picked=table.to_string(), + picked=picked_str, file_name=sample.file_name, jobtype="match_all" ) diff --git a/evals/registry/data/00_scipaper_affinity/samples.jsonl b/evals/registry/data/00_scipaper_affinity/samples.jsonl index 4d1ea75b61..728b661ad4 100644 --- a/evals/registry/data/00_scipaper_affinity/samples.jsonl +++ b/evals/registry/data/00_scipaper_affinity/samples.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c8151e952830ec4993a74d682cadf973f294d9faaa4cdae0908ca7c4e62dcc73 -size 1304 +oid sha256:937c850d5e2c3f0114f1e8d4a8789d2c80f7113440985c2cdee09793e546c192 +size 913 diff --git a/evals/reporters/DPTracking.py b/evals/reporters/DPTracking.py index 04aa695ef9..485170711c 100644 --- a/evals/reporters/DPTracking.py +++ b/evals/reporters/DPTracking.py @@ -11,7 +11,6 @@ import pandas as pd import aim from PIL import Image -from rdkit import Chem class DPTrackingReporter: @@ -22,25 +21,25 @@ def _convert_logger_table(df: pd.DataFrame) -> aim.Table: return aim.Table(aim_df) for col in aim_df.columns: i = 0 - while not aim_df.loc[i, col]: + while not aim_df[col].iloc[i]: i += 1 if i == aim_df.shape[0]: i = 0 break - data0 = aim_df.loc[i, col] - if isinstance(data0, Chem.Mol): - molfiles = [] - tmpdir = f"aim-tmp-{uuid.uuid4().hex}" - 
Path(tmpdir).mkdir(exist_ok=True, parents=True) - for i, mol in enumerate(aim_df[col]): - if mol: - molfile = f"{tmpdir}/{i}.sdf" - Chem.MolToMolFile(mol, molfile) - molfiles.append(molfile) - else: - molfiles.append(None) - aim_df[col] = [aim.Molecule(molfile) if molfile else None for molfile in molfiles] - elif isinstance(data0, Image.Image): + data0 = aim_df[col].iloc[i] + # if isinstance(data0, Chem.Mol): + # molfiles = [] + # tmpdir = f"aim-tmp-{uuid.uuid4().hex}" + # Path(tmpdir).mkdir(exist_ok=True, parents=True) + # for i, mol in enumerate(aim_df[col]): + # if mol: + # molfile = f"{tmpdir}/{i}.sdf" + # Chem.MolToMolFile(mol, molfile) + # molfiles.append(molfile) + # else: + # molfiles.append(None) + # aim_df[col] = [aim.Molecule(molfile) if molfile else None for molfile in molfiles] + if isinstance(data0, Image.Image): imgfiles = [] tmpdir = f"aim-tmp-{uuid.uuid4().hex}" Path(tmpdir).mkdir(exist_ok=True, parents=True) @@ -69,6 +68,8 @@ def _convert_logger_data(v: Any) -> Any: return DPTrackingReporter._convert_logger_table(v) if type(v) in [np.ndarray, list]: return aim.Distribution(v) + if type(v) == str: + return aim.Text(v) return v @staticmethod @@ -76,27 +77,26 @@ def report_run(config_logger: Dict, config_run: Dict = {}, logger_data: Dict = { dp_mlops_config = config_logger["dp_mlops"] # Experiment Tracking - os.environ["AIM_ACCESS_TOKEN"] = dp_mlops_config["aim_personal_token"] - print('debug report_sampler: run_hash', config_logger["hash"], datetime.now()) - # os.environ['AIM_UNSAFE_SESSION_COOKIE'] = config_logger["hash"] + if "aim_personal_token" in dp_mlops_config.keys(): + os.environ["AIM_ACCESS_TOKEN"] = dp_mlops_config["aim_personal_token"] run = aim.Run( experiment=config_logger["project"], run_hash=config_logger.get("hash", None), repo=dp_mlops_config["aim_repo"] ) - # run = Run(experiment=config_logger["project"], repo=dp_mlops_config["aim_repo"]) run.name = config_logger["name"] - run.hparams["config"] = config_run + run["config"] = config_run for tag in set([config_logger["name"]] + dp_mlops_config.get("tags", [])): if tag and tag.lower() not in [t.lower() for t in run.props.tags]: print(tag.lower(), run.props.tags) run.add_tag(tag.lower()) - DPTrackingReporter._convert_logger_data(logger_data) + logger_data_aim = {key: DPTrackingReporter._convert_logger_data(value) for key, value in logger_data.items()} - for key, value in logger_data.items(): - if "/" not in key or "kcal/mol" in key: - run.track(value, name=key, context={}) + for key, value in logger_data_aim.items(): + print(key, type(value)) + if "/" not in key or "kcal/mol" in key or "10.1021/" in key or "10.1016/" in key: + run.track(value, name=key) else: key, context_str = key.split("/") context_dict = {k: v for k, v in [kv.split(":") for kv in context_str.split(",")]} diff --git a/examples/config_logger.json b/examples/config_logger.json new file mode 100644 index 0000000000..dcfee21077 --- /dev/null +++ b/examples/config_logger.json @@ -0,0 +1,7 @@ +{ + "name": "20231231-unifinder-poc", + "project": "Uni-Finder/Benchmark", + "dp_mlops":{ + "aim_repo": "aim://tracking-api.mlops.dp.tech:443" + } +} From 30efde8e87844fa50e65702d5c1f7baca57d7984 Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Wed, 10 Jan 2024 04:49:48 +0800 Subject: [PATCH 09/22] add rag to openai native completion_fns --- evals/completion_fns/openai.py | 58 +++++++++++++++++------ evals/completion_fns/retrieval_native.py | 53 ++++----------------- evals/completion_fns/uni_finder.py | 2 - evals/utils/api_utils.py | 
59 ++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 61 deletions(-) diff --git a/evals/completion_fns/openai.py b/evals/completion_fns/openai.py index ed50818630..8cd95dff04 100644 --- a/evals/completion_fns/openai.py +++ b/evals/completion_fns/openai.py @@ -15,6 +15,7 @@ from evals.utils.api_utils import ( openai_chat_completion_create_retrying, openai_completion_create_retrying, + openai_rag_completion_create_retrying ) @@ -46,6 +47,15 @@ def get_completions(self) -> list[str]: return completions +class RetrievalCompletionResult(CompletionResult): + def __init__(self, response: str, prompt: Any) -> None: + self.response = response + self.prompt = prompt + + def get_completions(self) -> list[str]: + return [self.response.strip()] + + class OpenAICompletionFn(CompletionFn): def __init__( self, @@ -81,13 +91,22 @@ def __call__( openai_create_prompt: OpenAICreatePrompt = prompt.to_formatted_prompt() - result = openai_completion_create_retrying( - OpenAI(api_key=self.api_key, base_url=self.api_base), - model=self.model, - prompt=openai_create_prompt, - **{**kwargs, **self.extra_options}, - ) - result = OpenAICompletionResult(raw_data=result, prompt=openai_create_prompt) + if "file_name" not in kwargs: + result = openai_completion_create_retrying( + OpenAI(api_key=self.api_key, base_url=self.api_base), + model=self.model, + prompt=openai_create_prompt, + **{**kwargs, **self.extra_options}, + ) + result = OpenAICompletionResult(raw_data=result, prompt=openai_create_prompt) + else: + answer = openai_rag_completion_create_retrying( + OpenAI(api_key=self.api_key, base_url=self.api_base), + model=self.model, + instructions=kwargs.get("instructions", ""), + file_name=kwargs.get("file_name", ""), + ) + result = RetrievalCompletionResult(answer, prompt=openai_create_prompt) record_sampling(prompt=result.prompt, sampled=result.get_completions()) return result @@ -126,12 +145,23 @@ def __call__( openai_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt() - result = openai_chat_completion_create_retrying( - OpenAI(api_key=self.api_key, base_url=self.api_base), - model=self.model, - messages=openai_create_prompt, - **{**kwargs, **self.extra_options}, - ) - result = OpenAIChatCompletionResult(raw_data=result, prompt=openai_create_prompt) + if "file_name" not in kwargs: + result = openai_chat_completion_create_retrying( + OpenAI(api_key=self.api_key, base_url=self.api_base), + model=self.model, + messages=openai_create_prompt, + **{**kwargs, **self.extra_options}, + ) + result = OpenAIChatCompletionResult(raw_data=result, prompt=openai_create_prompt) + else: + chatmodel_to_apimodel = lambda x: "gpt-3.5-turbo-1106" if x.startswith("gpt-3.5-turbo-") else "gpt-4-1106-preview" if x.startswith("gpt-4-") else "" + answer = openai_rag_completion_create_retrying( + OpenAI(api_key=self.api_key, base_url=self.api_base), + model=chatmodel_to_apimodel(self.model), + instructions=kwargs.get("instructions", ""), + file_name=kwargs.get("file_name", ""), + prompt=CompletionPrompt(raw_prompt=openai_create_prompt).to_formatted_prompt() + ) + result = RetrievalCompletionResult(answer, prompt=openai_create_prompt) record_sampling(prompt=result.prompt, sampled=result.get_completions()) return result diff --git a/evals/completion_fns/retrieval_native.py b/evals/completion_fns/retrieval_native.py index eb5b801816..f06e2da423 100644 --- a/evals/completion_fns/retrieval_native.py +++ b/evals/completion_fns/retrieval_native.py @@ -10,19 +10,12 @@ from openai import OpenAI client = 
OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) -import pandas as pd from evals.api import CompletionFn, CompletionResult +from evals.completion_fns.openai import RetrievalCompletionResult from evals.prompt.base import ChatCompletionPrompt, CompletionPrompt from evals.record import record_sampling - - -class RetrievalCompletionResult(CompletionResult): - def __init__(self, response: str) -> None: - self.response = response - - def get_completions(self) -> list[str]: - return [self.response.strip()] +from evals.utils.api_utils import openai_rag_completion_create_retrying class OpenAIRetrievalCompletionFn(CompletionFn): @@ -56,42 +49,12 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> RetrievalCo assert "file_name" in kwargs, "Must provide a file_name to retrieve." - file = client.files.create(file=open(kwargs["file_name"], "rb"), purpose='assistants') - - # Create an Assistant (Note model="gpt-3.5-turbo-1106" instead of "gpt-4-1106-preview") - assistant = client.beta.assistants.create( - name="File Assistant", - instructions=self.instructions, + answer = openai_rag_completion_create_retrying( + client, model=self.model, - tools=[{"type": "retrieval"}], - file_ids=[file.id] + instructions=self.instructions, + file_name=kwargs.get("file_name", ""), + prompt=CompletionPrompt(raw_prompt=prompt).to_formatted_prompt(), ) - - # Create a Thread - thread = client.beta.threads.create() - - # Add a Message to a Thread - print(prompt) - message = client.beta.threads.messages.create(thread_id=thread.id, role="user", - content=prompt - ) - - # Run the Assistant - run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant.id) - print(run.model_dump_json(indent=4)) - - # If run is 'completed', get messages and print - while True: - # Retrieve the run status - run_status = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) - time.sleep(10) - if run_status.status == 'completed': - messages = client.beta.threads.messages.list(thread_id=thread.id) - answer = messages.data[0].content[0].text.value - break - else: - ### sleep again - time.sleep(2) - print(answer) record_sampling(prompt=prompt, sampled=answer) - return RetrievalCompletionResult(answer) + return RetrievalCompletionResult(answer, prompt=prompt) diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py index 32e61f324f..2169ba00a9 100644 --- a/evals/completion_fns/uni_finder.py +++ b/evals/completion_fns/uni_finder.py @@ -68,8 +68,6 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo pdf_id = response['pdf_token'] # 获得pdf的id,表示上传成功,后续可以使用这个id来指定pdf pdf_token.append(pdf_id) - assert "file_name" in kwargs, "Must provide a file_name to retrieve." 
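Both native completion functions above and `OpenAIRetrievalCompletionFn` now funnel through the shared `openai_rag_completion_create_retrying` helper (added to `evals/utils/api_utils.py` in this patch), so upload, assistant setup, and retry behavior live in one place. A minimal usage sketch; the file path and prompt are illustrative, not taken from the registry:

```python
import os

from openai import OpenAI

from evals.utils.api_utils import openai_rag_completion_create_retrying

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# Hypothetical inputs; real evals supply file_name via their samples.jsonl.
answer = openai_rag_completion_create_retrying(
    client,
    model="gpt-4-1106-preview",
    instructions="Answer questions about the attached paper.",
    file_name="paper.pdf",
    prompt="List all inhibitors reported in the paper and their IC50 values.",
)
print(answer)
```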
diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py
index 32e61f324f..2169ba00a9 100644
--- a/evals/completion_fns/uni_finder.py
+++ b/evals/completion_fns/uni_finder.py
@@ -68,8 +68,6 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
             pdf_id = response['pdf_token']  # PDF id returned on successful upload; later calls use it to reference the PDF
             pdf_token.append(pdf_id)
 
-        assert "file_name" in kwargs, "Must provide a file_name to retrieve."
-
         url = f"{self.api_base}/api/external/chatpdf"
 
         payload = {
diff --git a/evals/utils/api_utils.py b/evals/utils/api_utils.py
index ae6d34ae30..f6592919f9 100644
--- a/evals/utils/api_utils.py
+++ b/evals/utils/api_utils.py
@@ -4,6 +4,7 @@
 import concurrent
 import logging
 import os
+import time
 
 import backoff
 import openai
@@ -70,3 +71,61 @@ def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs):
         logging.warning(result)
         raise openai.error.APIError(result["error"])
     return result
+
+
+@backoff.on_exception(
+    wait_gen=backoff.expo,
+    exception=(
+        openai.RateLimitError,
+        openai.APIConnectionError,
+        openai.APITimeoutError,
+        openai.InternalServerError,
+    ),
+    max_value=60,
+    factor=1.5,
+)
+def openai_rag_completion_create_retrying(client: OpenAI, *args, **kwargs):
+    """
+    Helper function for creating a retrieval-augmented completion.
+    Reads `model`, `instructions`, `file_name` and `prompt` from `kwargs`:
+    the file is uploaded, attached to a temporary assistant with the
+    retrieval tool, and the prompt is answered against it.
+    """
+
+    with open(kwargs["file_name"], "rb") as fh:
+        file = client.files.create(file=fh, purpose='assistants')
+
+    # Create an Assistant with the retrieval tool and the uploaded file attached
+    assistant = client.beta.assistants.create(
+        name="File Assistant",
+        instructions=kwargs.get("instructions", ""),
+        model=kwargs.get("model", "gpt-3.5-turbo-1106"),
+        tools=[{"type": "retrieval"}],
+        file_ids=[file.id]
+    )
+
+    # Create a Thread
+    thread = client.beta.threads.create()
+
+    # Add a Message to a Thread
+    message = client.beta.threads.messages.create(thread_id=thread.id, role="user",
+                                                  content=kwargs.get("prompt", "")
+                                                  )
+
+    # Run the Assistant
+    run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant.id)
+
+    # Poll until the run reaches a terminal state; bail out instead of looping forever on failure
+    while True:
+        run_status = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
+        if run_status.status == 'completed':
+            messages = client.beta.threads.messages.list(thread_id=thread.id)
+            answer = messages.data[0].content[0].text.value
+            break
+        if run_status.status in ('failed', 'cancelled', 'expired'):
+            raise RuntimeError(f"Assistant run ended with status {run_status.status}")
+        time.sleep(10)
+
+    return answer

From 474d63bd229b1fd01b2c4189d8319e2bc04240e2 Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Wed, 10 Jan 2024 15:35:29 +0800
Subject: [PATCH 10/22] add RAG for match, modelgraded_classify, table_extract
 evals

---
 evals/completion_fns/openai.py                |   2 +-
 evals/completion_fns/uni_finder.py            |  16 +--
 evals/elsuite/modelgraded/rag_classify.py     | 131 ++++++++++++++++++
 evals/elsuite/rag_match.py                    | 119 ++++++++++++++++
 ...{table_extract.py => rag_table_extract.py} |  67 ++-------
 5 files changed, 266 insertions(+), 69 deletions(-)
 create mode 100644 evals/elsuite/modelgraded/rag_classify.py
 create mode 100644 evals/elsuite/rag_match.py
 rename evals/elsuite/{table_extract.py => rag_table_extract.py} (77%)

diff --git a/evals/completion_fns/openai.py b/evals/completion_fns/openai.py
index 8cd95dff04..b57570e0e9 100644
--- a/evals/completion_fns/openai.py
+++ b/evals/completion_fns/openai.py
@@ -154,7 +154,7 @@ def __call__(
             result = OpenAIChatCompletionResult(raw_data=result, prompt=openai_create_prompt)
         else:
-            chatmodel_to_apimodel = lambda x: "gpt-3.5-turbo-1106" if x.startswith("gpt-3.5-turbo-") else "gpt-4-1106-preview" if x.startswith("gpt-4-") else ""
+            chatmodel_to_apimodel = lambda x: "gpt-3.5-turbo-1106" if x.startswith("gpt-3.5-turbo") else "gpt-4-1106-preview" if x.startswith("gpt-4") else ""
             answer = openai_rag_completion_create_retrying(
                 OpenAI(api_key=self.api_key, base_url=self.api_base),
                 model=chatmodel_to_apimodel(self.model),
diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py
index 2169ba00a9..b499a39ea9 100644
--- a/evals/completion_fns/uni_finder.py
+++ b/evals/completion_fns/uni_finder.py
@@ -6,16 +6,9 @@
 import requests
 from typing import Any, Optional, Union
 
-import numpy as np
-from openai import OpenAI
-
-import pandas as pd
-
+from evals.prompt.base import CompletionPrompt
 from evals.api import CompletionFn, CompletionResult
 from evals.record import record_sampling
-from evals.utils.api_utils import (
-    request_with_timeout
-)
 
 
 class UniFinderCompletionResult(CompletionResult):
@@ -23,7 +16,7 @@ def __init__(self, response: str) -> None:
         self.response = response
 
     def get_completions(self) -> list[str]:
-        return [self.response.strip()]
+        return [self.response.strip()] if self.response else ["Unknown"]
 
 
 class UniFinderCompletionFn(CompletionFn):
@@ -66,10 +59,14 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
             }
             response = requests.post(url, data=data, files=files).json()
             pdf_id = response['pdf_token']  # PDF id returned on successful upload; later calls use it to reference the PDF
+            print("############# pdf_id ##############", pdf_id)
             pdf_token.append(pdf_id)
 
         url = f"{self.api_base}/api/external/chatpdf"
 
+        if type(prompt) == list:
+            prompt = CompletionPrompt(prompt).to_formatted_prompt()
+
         payload = {
             "model_engine": self.model,
             "pdf_token": pdf_token,
@@ -78,6 +75,5 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
         }
         response = requests.post(url, json=payload).json()
         answer = response['answer']
-        print(answer)
         record_sampling(prompt=prompt, sampled=answer)
         return UniFinderCompletionResult(answer)
diff --git a/evals/elsuite/modelgraded/rag_classify.py b/evals/elsuite/modelgraded/rag_classify.py
new file mode 100644
index 0000000000..aa471ebd39
--- /dev/null
+++ b/evals/elsuite/modelgraded/rag_classify.py
@@ -0,0 +1,131 @@
+"""
+Generic eval that uses a prompt + classification.
+"""
+from collections import Counter
+from random import Random
+from typing import Any, Optional, Union
+
+import evals
+import evals.record
+from evals.elsuite.modelgraded.classify_utils import classify, sample_and_concat_n_completions
+from evals.elsuite.rag_match import get_rag_dataset
+from evals.elsuite.utils import PromptFn, scrub_formatting_from_prompt
+
+
+class RAGModelBasedClassify(evals.Eval):
+    def __init__(
+        self,
+        modelgraded_spec: str,
+        *args,
+        modelgraded_spec_args: Optional[dict[str, dict[str, str]]] = None,
+        sample_kwargs: Optional[dict[str, Any]] = None,
+        eval_kwargs: Optional[dict[str, Any]] = None,
+        multicomp_n: Union[int, str] = 1,
+        eval_type: Optional[str] = None,
+        match_fn: Optional[str] = None,
+        metaeval: bool = False,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        # treat last completion_fn as eval_completion_fn
+        self.eval_completion_fn = self.completion_fns[-1]
+        if len(self.completion_fns) > 1:
+            self.completion_fns = self.completion_fns[:-1]
+        n_models = len(self.completion_fns)
+        self.sample_kwargs = {"max_tokens": 1024}
+        self.sample_kwargs.update(sample_kwargs or {})
+        self.eval_kwargs = {"max_tokens": 1024}
+        self.eval_kwargs.update(eval_kwargs or {})
+        self.metaeval = metaeval
+        self.modelgraded_spec_args = modelgraded_spec_args or {}
+        self.eval_type = eval_type
+        self.match_fn = match_fn
+        if multicomp_n == "from_models":
+            assert n_models > 1
+            self.multicomp_n = n_models
+        else:
+            assert isinstance(multicomp_n, int)
+            self.multicomp_n = multicomp_n
+        if len(self.completion_fns) > 1:
+            assert self.multicomp_n == n_models
+
+        self.mg = self.registry.get_modelgraded_spec(modelgraded_spec)
+
+    def eval_sample(self, test_sample: dict, rng: Random) -> None:
+        """Evaluate a single sample.
+
+        Recorded metrics are always: one of the self.choice_strings, or "__invalid__".
+        """
+        # process test_sample
+        sample_file_dict = {key: value for key, value in test_sample.items() if key.startswith("file")}
+        test_sample = {key: value for key, value in test_sample.items() if not key.startswith("file")}
+        for k in self.mg.input_outputs:
+            test_sample[k] = scrub_formatting_from_prompt(test_sample[k])
+
+        # run policy completions
+        completions = {}
+        for k, v in self.mg.input_outputs.items():
+            if v in test_sample:  # test_sample already has completion, skip.
+                continue
+
+            if self.multicomp_n > 1:
+                completion = sample_and_concat_n_completions(
+                    self.completion_fns,
+                    prompt=test_sample[k],
+                    template_i=self.mg.output_template,
+                    sample_kwargs={**self.sample_kwargs, "completion_kwargs": sample_file_dict},
+                    n=self.multicomp_n,
+                )
+            else:
+                get_input_completion = PromptFn(
+                    test_sample[k], completion_fn=self.completion_fn, **{**self.sample_kwargs, "completion_kwargs": sample_file_dict}
+                )
+                completion, _ = get_input_completion()
+            completions[v] = completion
+
+        # run modelgraded eval
+        metrics = {}
+        choice, info = classify(
+            mg=self.mg,
+            completion_fn=self.eval_completion_fn,
+            completion_kwargs=self.eval_kwargs,
+            eval_type=self.eval_type,
+            n=self.multicomp_n,
+            match_fn=self.match_fn,
+            format_kwargs={**completions, **test_sample, **self.modelgraded_spec_args},
+        )
+        metrics.update(dict(choice=choice, score=info["score"]))
+
+        # run metaeval if requested
+        if self.metaeval:
+            assert "choice" in test_sample
+            metrics["metascore"] = choice == test_sample["choice"]
+
+        evals.record.record_metrics(**metrics)
+
+        return choice
+
+    def run(self, recorder):
+        samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
+
+        self.eval_all_samples(recorder, samples)
+        record_metrics = {}
+
+        all_sample_metrics = recorder.get_metrics()
+        if not all_sample_metrics:
+            return record_metrics
+
+        # record the counts
+        choices = [m["choice"] for m in all_sample_metrics]
+        counts = dict(Counter(choices))
+        record_metrics.update({f"counts/{k}": v for k, v in counts.items()})
+
+        # record the scores
+        scores = [m["score"] for m in all_sample_metrics if m["score"] is not None]
+        if scores:
+            record_metrics["score"] = sum(scores) / len(scores)
+        metascores = [m["metascore"] for m in all_sample_metrics if "metascore" in m]
+        if metascores:
+            record_metrics["metascore"] = sum(metascores) / len(metascores)
+
+        return record_metrics
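For context, `get_rag_dataset` in `rag_match.py` (next diff) drives all of these RAG evals: each JSONL record carries the usual `input`/`ideal` fields plus a `file_name` and/or `file_link`, and the loader syncs the file with OSS in whichever direction is missing. A hypothetical record, written from Python; every value below is invented for illustration:

```python
import json

# Field names follow get_rag_dataset/RAGMatch; values are made up.
sample = {
    "input": [{"role": "user", "content": "Which target does compound 5a inhibit?"}],
    "ideal": "5-HT1A",
    # Local path; get_rag_dataset uploads it to OSS and fills in file_link.
    "file_name": "evals/registry/data/01_scipaper_tag2mol/pdfs/example.pdf",
}
with open("samples.jsonl", "a") as fh:
    fh.write(json.dumps(sample) + "\n")
```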
diff --git a/evals/elsuite/rag_match.py b/evals/elsuite/rag_match.py
new file mode 100644
index 0000000000..1678582730
--- /dev/null
+++ b/evals/elsuite/rag_match.py
@@ -0,0 +1,119 @@
+import os
+from pathlib import Path
+from typing import Any
+
+import oss2
+from oss2.credentials import EnvironmentVariableCredentialsProvider
+
+import evals
+import evals.metrics
+from evals.api import CompletionFn
+from evals.prompt.base import is_chat_prompt
+
+
+def init_oss():
+    """
+    Initialize OSS client.
+    """
+    # Please set OSS_ACCESS_KEY_ID & OSS_ACCESS_KEY_SECRET in your environment variables.
+    auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())
+
+    # OSS endpoint (Beijing region)
+    endpoint = 'https://oss-cn-beijing.aliyuncs.com'
+
+    # Target bucket
+    bucket_name = 'dp-filetrans-bj'
+    bucket = oss2.Bucket(auth, endpoint, bucket_name)
+
+    return bucket
+
+
+def get_rag_dataset(samples_jsonl: str) -> list[dict]:
+    bucket = init_oss()
+    raw_samples = evals.get_jsonl(samples_jsonl)
+
+    for raw_sample in raw_samples:
+        for ftype in ["", "answer"]:
+            if f"{ftype}file_name" not in raw_sample and f"{ftype}file_link" not in raw_sample:
+                continue
+            if f"{ftype}file_name" in raw_sample:
+                oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_name"])
+                raw_sample[f"{ftype}file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file
+
+                exists = bucket.object_exists(oss_file)
+                if exists:
+                    print(f"File {oss_file} already exists in OSS.")
+                else:
+                    # Upload the local file
+                    bucket.put_object_from_file(oss_file, raw_sample[f"{ftype}file_name"])
+                    print(f"File {oss_file} uploaded to OSS.")
+            elif f"{ftype}file_link" in raw_sample:
+                local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else os.path.basename(
+                    raw_sample[f"{ftype}file_link"])
+                oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_link"])
+                if not os.path.exists(local_file):
+                    if bucket.object_exists(oss_file):
+                        # Download the file from OSS
+                        Path(local_file).parent.mkdir(parents=True, exist_ok=True)
+                        bucket.get_object_to_file(oss_file, local_file)
+    return raw_samples
+
+
+class RAGMatch(evals.Eval):
+    def __init__(
+        self,
+        completion_fns: list[CompletionFn],
+        samples_jsonl: str,
+        *args,
+        max_tokens: int = 500,
+        num_few_shot: int = 0,
+        few_shot_jsonl: str = None,
+        **kwargs,
+    ):
+        super().__init__(completion_fns, *args, **kwargs)
+        assert len(completion_fns) == 1, "Match only supports one completion fn"
+        self.max_tokens = max_tokens
+        self.samples_jsonl = samples_jsonl
+        self.num_few_shot = num_few_shot
+        if self.num_few_shot > 0:
+            assert few_shot_jsonl is not None, "few shot requires few shot sample dataset"
+            self.few_shot_jsonl = few_shot_jsonl
+            self.few_shot = evals.get_jsonl(self._prefix_registry_path(self.few_shot_jsonl))
+
+    def eval_sample(self, sample: Any, *_):
+        assert isinstance(sample, dict), "sample must be a dict"
+        assert "input" in sample, "sample must have an 'input' key"
+        assert "ideal" in sample, "sample must have an 'ideal' key"
+        assert isinstance(sample["ideal"], str) or isinstance(
+            sample["ideal"], list
+        ), "sample['ideal'] must be a string or list of strings"
+
+        prompt = sample["input"]
+        if self.num_few_shot > 0:
+            assert is_chat_prompt(sample["input"]), "few shot requires chat prompt"
+            prompt = sample["input"][:-1]
+            for s in self.few_shot[: self.num_few_shot]:
+                prompt += s["sample"]
+            prompt += sample["input"][-1:]
+
+        result = self.completion_fn(
+            prompt=prompt,
+            temperature=0.0,
+            **{k: v for k, v in sample.items() if k not in ["input", "ideal"]}
+        )
+        sampled = result.get_completions()[0]
+
+        return evals.record_and_check_match(
+            prompt=prompt,
+            sampled=sampled,
+            expected=sample["ideal"],
+        )
+
+    def run(self, recorder):
+        samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
+        self.eval_all_samples(recorder, samples)
+        events = recorder.get_events("match")
+        return {
+            "accuracy": evals.metrics.get_accuracy(events),
+            "bootstrap_std": evals.metrics.get_bootstrap_accuracy_std(events),
+        }
diff --git a/evals/elsuite/table_extract.py b/evals/elsuite/rag_table_extract.py
similarity index 77%
rename from evals/elsuite/table_extract.py
rename to evals/elsuite/rag_table_extract.py
index e3aaa98d5d..3099d08b50 100644
--- a/evals/elsuite/table_extract.py
+++ b/evals/elsuite/rag_table_extract.py
@@ -1,20 +1,16 @@
 from io import StringIO
 import json
-import os
-from pathlib import Path
 import re
 from typing import List, Optional, Tuple, Union
 
-import oss2
-from oss2.credentials import EnvironmentVariableCredentialsProvider
-
 import pandas as pd
 from pydantic import BaseModel
 
 import evals
 import evals.metrics
 from evals.api import CompletionFn
+from evals.elsuite.rag_match import get_rag_dataset
 from evals.record import RecorderBase, record_match
 
 code_pattern = r"```[\s\S]*?\n([\s\S]+?)\n```"
@@ -51,23 +47,6 @@ def parse_table_multiindex(table: pd.DataFrame) -> pd.DataFrame:
     return df
 
 
-def init_oss():
-    """
-    Initialize OSS client.
-    """
-    # Please set OSS_ACCESS_KEY_ID & OSS_ACCESS_KEY_SECRET in your environment variables.
-    auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())
-
-    # OSS endpoint (Beijing region)
-    endpoint = 'https://oss-cn-beijing.aliyuncs.com'
-
-    # Target bucket
-    bucket_name = 'dp-filetrans-bj'
-    bucket = oss2.Bucket(auth, endpoint, bucket_name)
-
-    return bucket
-
-
 class FileSample(BaseModel):
     file_name: Optional[str]
     file_link: Optional[str]
@@ -76,39 +55,6 @@ class FileSample(BaseModel):
     compare_fields: List[Union[str, Tuple]]
 
 
-def get_dataset(data_jsonl: str) -> list[FileSample]:
-    bucket = init_oss()
-    raw_samples = evals.get_jsonl(data_jsonl)
-
-    for raw_sample in raw_samples:
-        for ftype in ["", "answer"]:
-            if f"{ftype}file_name" in raw_sample:
-                oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_name"])
-                raw_sample[f"{ftype}file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file
-
-                exists = bucket.object_exists(oss_file)
-                if exists:
-                    print(f"File {oss_file} already exists in OSS.")
-                else:
-                    # Upload the local file
-                    bucket.put_object_from_file(oss_file, raw_sample[f"{ftype}file_name"])
-                    print(f"File {oss_file} uploaded to OSS.")
-            elif f"{ftype}file_link" in raw_sample:
-                local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else os.path.basename(
-                    raw_sample[f"{ftype}file_link"])
-                oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_link"])
-                if not os.path.exists(local_file):
-                    if bucket.object_exists(oss_file):
-                        # Download the file from OSS
-                        Path(local_file).parent.mkdir(parents=True, exist_ok=True)
-                        bucket.get_object_to_file(oss_file, local_file)
-    raw_sample["compare_fields"] = [field if type(field) == str else tuple(field) for field in
-                                    raw_sample["compare_fields"]]
-    print(raw_samples)
-    samples = [FileSample(**raw_sample) for raw_sample in raw_samples]
-    return samples
-
-
 def fuzzy_compare(a: str, b: str) -> bool:
     """
     Compare two strings with fuzzy matching.
@@ -178,14 +124,14 @@ class TableExtract(evals.Eval):
     def __init__(
         self,
         completion_fns: list[CompletionFn],
-        dataset: str,
+        samples_jsonl: str,
         *args,
         instructions: Optional[str] = "",
         **kwargs,
     ):
         super().__init__(completion_fns, *args, **kwargs)
         assert len(completion_fns) < 3, "TableExtract supports at most two completion fns"
-        self.dataset = dataset
+        self.samples_jsonl = samples_jsonl
         self.instructions = instructions
 
     def eval_sample(self, sample, rng):
@@ -281,7 +227,12 @@ def eval_sample(self, sample, rng):
         )
 
     def run(self, recorder: RecorderBase):
-        samples = get_dataset(self.dataset)
+        raw_samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
+        for raw_sample in raw_samples:
+            raw_sample["compare_fields"] = [field if type(field) == str else tuple(field) for field in
+                                            raw_sample["compare_fields"]]
+
+        samples = [FileSample(**raw_sample) for raw_sample in raw_samples]
         self.eval_all_samples(recorder, samples)
         return {
             "accuracy": evals.metrics.get_accuracy(recorder.get_events("match")),

From d7213e0532e8e291610d98ed0d14197aec3019c9 Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Wed, 10 Jan 2024 15:35:51 +0800
Subject: [PATCH 11/22] add scipaper_tag2mol, scipaper_hasmol, scipaper_targets
 and markush2mol evals

---
 .../data/00_scipaper_affinity/samples.jsonl      |  4 ++--
 .../data/01_scipaper_hasmol/samples.jsonl        |  3 +++
 .../data/01_scipaper_tag2mol/samples.jsonl       |  3 +++
 evals/registry/data/02_markush2mol/samples.jsonl |  3 +++
 .../data/03_scipaper_targets/samples.jsonl       |  3 +++
 evals/registry/evals/00_scipaper_affinity.yaml   | 16 ++++++++--------
 evals/registry/evals/01_scipaper_tag2mol.yaml    |  8 ++++++++
 evals/registry/evals/02_markush2mol.yaml         |  8 ++++++++
 evals/registry/evals/03_scipaper_targets.yaml    | 12 ++++++++++++
 9 files changed, 50 insertions(+), 10 deletions(-)
 create mode 100644 evals/registry/data/01_scipaper_hasmol/samples.jsonl
 create mode 100644 evals/registry/data/01_scipaper_tag2mol/samples.jsonl
 create mode 100644 evals/registry/data/02_markush2mol/samples.jsonl
 create mode 100644 evals/registry/data/03_scipaper_targets/samples.jsonl
 create mode 100644 evals/registry/evals/01_scipaper_tag2mol.yaml
 create mode 100644 evals/registry/evals/02_markush2mol.yaml
 create mode 100644 evals/registry/evals/03_scipaper_targets.yaml

diff --git a/evals/registry/data/00_scipaper_affinity/samples.jsonl b/evals/registry/data/00_scipaper_affinity/samples.jsonl
index 728b661ad4..5f093873e3 100644
--- a/evals/registry/data/00_scipaper_affinity/samples.jsonl
+++ b/evals/registry/data/00_scipaper_affinity/samples.jsonl
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:937c850d5e2c3f0114f1e8d4a8789d2c80f7113440985c2cdee09793e546c192
-size 913
+oid sha256:3f95fad1b8c6426a1acef3ae0190d4c81f81f5ebb0b8ed6fe3b39c683f77b7ee
+size 5047
diff --git a/evals/registry/data/01_scipaper_hasmol/samples.jsonl b/evals/registry/data/01_scipaper_hasmol/samples.jsonl
new file mode 100644
index 0000000000..2962fd16b8
--- /dev/null
+++ b/evals/registry/data/01_scipaper_hasmol/samples.jsonl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac93b432f74bf7347b09b91a1c501738b65d0f2cfff6fdfdb6acc9786430ac86
+size 2151
diff --git a/evals/registry/data/01_scipaper_tag2mol/samples.jsonl b/evals/registry/data/01_scipaper_tag2mol/samples.jsonl
new file mode 100644
index 0000000000..805bb85da8
--- /dev/null
+++ b/evals/registry/data/01_scipaper_tag2mol/samples.jsonl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b49af9ced518fdc5477f229494a29e9b927349dddad630217c44eaa813af4abc
+size 2131
diff --git a/evals/registry/data/02_markush2mol/samples.jsonl b/evals/registry/data/02_markush2mol/samples.jsonl
new file mode 100644
index 0000000000..e6ddf2c625
--- /dev/null
+++ b/evals/registry/data/02_markush2mol/samples.jsonl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fda4937180b350dacffafdd5dcaa299fca3a1468f269be03ef4b158a50e8e02
+size 502
diff --git a/evals/registry/data/03_scipaper_targets/samples.jsonl b/evals/registry/data/03_scipaper_targets/samples.jsonl
new file mode 100644
index 0000000000..d14eb83645
--- /dev/null
+++ b/evals/registry/data/03_scipaper_targets/samples.jsonl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0efe816302535508a331db6feeb3cbbd36a95ea1c59d16f11e41a5e80a47daa8
+size 3482
diff --git a/evals/registry/evals/00_scipaper_affinity.yaml b/evals/registry/evals/00_scipaper_affinity.yaml
index de865a86a9..236dde23ca 100644
--- a/evals/registry/evals/00_scipaper_affinity.yaml
+++ b/evals/registry/evals/00_scipaper_affinity.yaml
@@ -2,9 +2,9 @@ scipaper_affinity:
   id: scipaper_affinity.val.ab-v1
   metrics: [accuracy]
 scipaper_affinity.val.json:
-  class: evals.elsuite.table_extract:TableExtract
+  class: evals.elsuite.rag_table_extract:TableExtract
   args:
-    dataset: 00_scipaper_affinity/samples.jsonl
+    samples_jsonl: 00_scipaper_affinity/samples.jsonl
     instructions: |
       Please give a complete list of names, affinities, target info (protein or cell line), and affinity types (chosen from IC50, EC50, TC50, GI50, Ki, Kd) of all inhibitors in the paper. If there are multiple tables, combine them. Don't give me references. Output in json format. For example:
       ```json
@@ -23,16 +23,16 @@ scipaper_affinity.val.json:
 
 scipaper_affinity.val.csv:
-  class: evals.elsuite.table_extract:TableExtract
+  class: evals.elsuite.rag_table_extract:TableExtract
   args:
-    dataset: 00_scipaper_affinity/samples.jsonl
+    samples_jsonl: 00_scipaper_affinity/samples.jsonl
     instructions: |
-      Please give a complete list of affinities, target info (protein or cell line), and affinity types (chosen from IC50, EC50, TC50, GI50, Ki, Kd) of all inhibitors in the paper.
+      Please give a complete list of SMILES structures, affinities, target info (protein or cell line), and affinity types (chosen from IC50, EC50, TC50, GI50, Ki, Kd) of all compounds in the paper. Usually the compounds' tags are numbers.
       1. Find all the tables with relevant information
       2. Output in csv format with multiindex (Affinities, protein/cell line), write units not in header but in the value like "10.5 µM". Quote the value if it has comma! For example:
       ```csv
-      Compound,Name,Affinities,Affinities,Affinities,Affinities
-      ,,5HT1A (IC50),5HT1D (IC50),5HT-UT (IC50),5HT1E ()
-      "5a","1,2-dimethyl Aspirin",2.0 nM,8.0 nM,12.6 nM,>1000 nM
+      Compound,Name,SMILES,Affinities,Affinities,Affinities,Affinities
+      ,,,5HT1A (IC50),5HT1D (IC50),5HT-UT (IC50),5HT1E ()
+      "5a","Aspirin","CC(=O)Oc1ccccc1C(=O)O",2.0 nM,8.0 nM,12.6 nM,>1000 nM
       ```
       3. If there are multiple tables, concat them. Don't give me references or use "...", give me the complete table!
\ No newline at end of file
diff --git a/evals/registry/evals/01_scipaper_tag2mol.yaml b/evals/registry/evals/01_scipaper_tag2mol.yaml
new file mode 100644
index 0000000000..556c73e1b9
--- /dev/null
+++ b/evals/registry/evals/01_scipaper_tag2mol.yaml
@@ -0,0 +1,8 @@
+scipaper_tag2mol:
+  id: scipaper_tag2mol.dev.v0
+  metrics: [accuracy]
+
+scipaper_tag2mol.dev.v0:
+  class: evals.elsuite.rag_match:RAGMatch
+  args:
+    samples_jsonl: 01_scipaper_tag2mol/samples.jsonl
\ No newline at end of file
diff --git a/evals/registry/evals/02_markush2mol.yaml b/evals/registry/evals/02_markush2mol.yaml
new file mode 100644
index 0000000000..d564774e9a
--- /dev/null
+++ b/evals/registry/evals/02_markush2mol.yaml
@@ -0,0 +1,8 @@
+markush2mol:
+  id: markush2mol.dev.v0
+  metrics: [accuracy]
+
+markush2mol.dev.v0:
+  class: evals.elsuite.basic.match:Match
+  args:
+    samples_jsonl: 02_markush2mol/samples.jsonl
\ No newline at end of file
diff --git a/evals/registry/evals/03_scipaper_targets.yaml b/evals/registry/evals/03_scipaper_targets.yaml
new file mode 100644
index 0000000000..722a1eb7eb
--- /dev/null
+++ b/evals/registry/evals/03_scipaper_targets.yaml
@@ -0,0 +1,12 @@
+scipaper_targets:
+  id: scipaper_targets.test.v1
+  metrics: [accuracy]
+  description: Test the model's ability to retrieve protein/cell line targets from literature.
+
+scipaper_targets.test.v1:
+  class: evals.elsuite.modelgraded.rag_classify:RAGModelBasedClassify
+  args:
+    samples_jsonl: 03_scipaper_targets/samples.jsonl
+    modelgraded_spec: closedqa
+    modelgraded_spec_args:
+      criteria: "conciseness: Does the answer have the same biological meaning as the content?"

From 3f3077257d94833ad0a0ab5d9fb460765b8fbfb6 Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Wed, 10 Jan 2024 15:47:09 +0800
Subject: [PATCH 12/22] add Chemistry evalset

---
 evals/registry/eval_sets/chemistry.yaml        | 11 +++++++++++
 evals/registry/evals/00_scipaper_affinity.yaml |  2 +-
 evals/registry/evals/01_scipaper_hasmol.yaml   |  8 ++++++++
 3 files changed, 20 insertions(+), 1 deletion(-)
 create mode 100644 evals/registry/eval_sets/chemistry.yaml
 create mode 100644 evals/registry/evals/01_scipaper_hasmol.yaml

diff --git a/evals/registry/eval_sets/chemistry.yaml b/evals/registry/eval_sets/chemistry.yaml
new file mode 100644
index 0000000000..421901e504
--- /dev/null
+++ b/evals/registry/eval_sets/chemistry.yaml
@@ -0,0 +1,11 @@
+chemistry:
+  evals:
+    - scipaper_affinity
+    - scipaper_tag2mol
+    - scipaper_hasmol
+    - markush2mol
+    - scipaper_targets
+    - abstract2title
+    - research-question-extraction
+    - balance-chemical-equation
+    - medmcqa
\ No newline at end of file
diff --git a/evals/registry/evals/00_scipaper_affinity.yaml b/evals/registry/evals/00_scipaper_affinity.yaml
index 236dde23ca..e02548dcc3 100644
--- a/evals/registry/evals/00_scipaper_affinity.yaml
+++ b/evals/registry/evals/00_scipaper_affinity.yaml
@@ -1,5 +1,5 @@
 scipaper_affinity:
-  id: scipaper_affinity.val.ab-v1
+  id: scipaper_affinity.val.csv
   metrics: [accuracy]
 scipaper_affinity.val.json:
   class: evals.elsuite.rag_table_extract:TableExtract
diff --git a/evals/registry/evals/01_scipaper_hasmol.yaml b/evals/registry/evals/01_scipaper_hasmol.yaml
new file mode 100644
index 0000000000..dfb0f7445a
--- /dev/null
+++ b/evals/registry/evals/01_scipaper_hasmol.yaml
@@ -0,0 +1,8 @@
+scipaper_hasmol:
+  id: scipaper_hasmol.dev.v0
+  metrics: [accuracy]
+
+scipaper_hasmol.dev.v0:
+  class: evals.elsuite.rag_match:RAGMatch
+  args:
+    samples_jsonl: 01_scipaper_has2mol/samples.jsonl
\ No newline at end of file
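With the eval set and its members registered, a benchmark run goes through the standard CLI, with the `--mlops` flag from PATCH 08 pointing at a logger config. A sketch; the completion-function name is illustrative, since the registered names live under `evals/registry/completion_fns/`:

```python
# Hedged sketch: launching one eval of the chemistry set programmatically.
import subprocess

subprocess.run(
    [
        "oaieval",
        "uni_finder",                  # illustrative completion-fn name
        "scipaper_affinity.val.csv",   # variant from 00_scipaper_affinity.yaml
        "--mlops", "examples/config_logger.json",
    ],
    check=True,
)
```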

From b861b7d303c801fc675a27fd24ed769f21859762 Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Mon, 15 Jan 2024 11:15:34 +0800
Subject: [PATCH 13/22] bugfixes

---
 evals/completion_fns/uni_finder.py           | 34 ++++++++++++++------
 evals/elsuite/rag_table_extract.py           |  2 +-
 evals/registry/evals/01_scipaper_hasmol.yaml |  2 +-
 3 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py
index b499a39ea9..14281110f9 100644
--- a/evals/completion_fns/uni_finder.py
+++ b/evals/completion_fns/uni_finder.py
@@ -1,8 +1,11 @@
 """
 Extending Completion Functions with Embeddings-based retrieval from a fetched dataset
 """
+import json
 import os
 import time
+from pathlib import Path
+
 import requests
 from typing import Any, Optional, Union
 
@@ -31,6 +34,7 @@ def __init__(
         api_base: Optional[str] = None,
         api_key: Optional[str] = None,
         n_ctx: Optional[int] = None,
+        cache_dir: Optional[str] = "~/.uni_finder/knowledge_base.json",
         extra_options: Optional[dict] = {},
         **kwargs
     ):
@@ -40,6 +44,10 @@ def __init__(
         self.api_key = api_key or os.environ.get("UNIFINDER_API_KEY")
         self.n_ctx = n_ctx
         self.extra_options = extra_options
+        self.cache_dir = cache_dir
+        Path(self.cache_dir).parent.mkdir(parents=True, exist_ok=True)
+        if not Path(self.cache_dir).exists():
+            json.dump({}, open(self.cache_dir, "w"))
 
     def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCompletionResult:
         """
@@ -50,15 +58,23 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
 
         pdf_token = []
         if "file_name" in kwargs:
-            url = f"{self.api_base}/api/external/upload_pdf"
-            pdf_parse_mode = 'fast'  # or 'precise'; selects the PDF parsing pipeline
-            files = {'file': open(kwargs["file_name"], 'rb')}
-            data = {
-                'pdf_parse_mode': pdf_parse_mode,
-                'api_key': self.api_key
-            }
-            response = requests.post(url, data=data, files=files).json()
-            pdf_id = response['pdf_token']  # PDF id returned on successful upload; later calls use it to reference the PDF
+            cache = json.load(open(self.cache_dir, 'r+'))
+
+            if kwargs["file_name"] not in cache:
+                url = f"{self.api_base}/api/external/upload_pdf"
+                pdf_parse_mode = 'fast'  # or 'precise'; selects the PDF parsing pipeline
+                files = {'file': open(kwargs["file_name"], 'rb')}
+                data = {
+                    'pdf_parse_mode': pdf_parse_mode,
+                    'api_key': self.api_key
+                }
+                response = requests.post(url, data=data, files=files).json()
+                pdf_id = response['pdf_token']  # PDF id returned on successful upload; later calls use it to reference the PDF
+
+                cache[kwargs["file_name"]] = pdf_id
+                json.dump(cache, open(self.cache_dir, "w"))
+            else:
+                pdf_id = cache[kwargs["file_name"]]
             print("############# pdf_id ##############", pdf_id)
             pdf_token.append(pdf_id)
 
diff --git a/evals/elsuite/rag_table_extract.py b/evals/elsuite/rag_table_extract.py
index 3099d08b50..8ecb0b86cb 100644
--- a/evals/elsuite/rag_table_extract.py
+++ b/evals/elsuite/rag_table_extract.py
@@ -186,7 +186,7 @@ def eval_sample(self, sample, rng):
 
         match_all = True
         for field in sample.compare_fields:
-            if type(field) == tuple:
+            if type(field) == tuple and len(field) > 1:
                 field = (field[0], fuzzy_normalize(field[1]))
             match_field = field in table.columns and field in correct_answer.columns
             match_all = match_all and match_field
diff --git a/evals/registry/evals/01_scipaper_hasmol.yaml b/evals/registry/evals/01_scipaper_hasmol.yaml
index dfb0f7445a..c176f7dc06 100644
--- a/evals/registry/evals/01_scipaper_hasmol.yaml
+++ b/evals/registry/evals/01_scipaper_hasmol.yaml
@@ -5,4 +5,4 @@ scipaper_hasmol:
 scipaper_hasmol.dev.v0:
   class: evals.elsuite.rag_match:RAGMatch
   args:
-    samples_jsonl: 01_scipaper_has2mol/samples.jsonl
\ No newline at end of file
+    samples_jsonl: 01_scipaper_hasmol/samples.jsonl
\ No newline at end of file

From e52d776b5a45c2397f4d3570ec48e894e0edee83 Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Thu, 18 Jan 2024 10:55:33 +0800
Subject: [PATCH 14/22] table comparison with self-defined index

---
 evals/completion_fns/uni_finder.py |  8 ++++++--
 evals/elsuite/rag_match.py         |  7 ++++---
 evals/elsuite/rag_table_extract.py | 22 ++++++++++++++++------
 3 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py
index 14281110f9..cc1ffc2506 100644
--- a/evals/completion_fns/uni_finder.py
+++ b/evals/completion_fns/uni_finder.py
@@ -89,7 +89,11 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
             "query": prompt,
             'api_key': self.api_key
         }
-        response = requests.post(url, json=payload).json()
-        answer = response['answer']
+        response = requests.post(url, json=payload)
+        try:
+            answer = response.json()['answer']
+        except:
+            print(response.text)
+            answer = response.text
         record_sampling(prompt=prompt, sampled=answer)
         return UniFinderCompletionResult(answer)
diff --git a/evals/elsuite/rag_match.py b/evals/elsuite/rag_match.py
index 1678582730..e541e520e8 100644
--- a/evals/elsuite/rag_match.py
+++ b/evals/elsuite/rag_match.py
@@ -47,15 +47,16 @@ def get_rag_dataset(samples_jsonl: str) -> list[dict]:
                     # Upload the local file
                     bucket.put_object_from_file(oss_file, raw_sample[f"{ftype}file_name"])
                     print(f"File {oss_file} uploaded to OSS.")
-            elif f"{ftype}file_link" in raw_sample:
-                local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else os.path.basename(
-                    raw_sample[f"{ftype}file_link"])
+            if f"{ftype}file_link" in raw_sample:
+                local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else \
+                    os.path.basename(raw_sample[f"{ftype}file_link"])
                 oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_link"])
                 if not os.path.exists(local_file):
                     if bucket.object_exists(oss_file):
                         # Download the file from OSS
                         Path(local_file).parent.mkdir(parents=True, exist_ok=True)
                         bucket.get_object_to_file(oss_file, local_file)
+                        print(f"File {oss_file} downloaded locally.")
     return raw_samples
diff --git a/evals/elsuite/rag_table_extract.py b/evals/elsuite/rag_table_extract.py
index 8ecb0b86cb..ec2447e926 100644
--- a/evals/elsuite/rag_table_extract.py
+++ b/evals/elsuite/rag_table_extract.py
@@ -53,6 +53,7 @@ class FileSample(BaseModel):
     answerfile_name: Optional[str]
     answerfile_link: Optional[str]
     compare_fields: List[Union[str, Tuple]]
+    index: Union[str, Tuple] = ("Compound", "")
 
 
 def fuzzy_compare(a: str, b: str) -> bool:
@@ -66,7 +67,7 @@ def standardize_unit(s: str) -> str:
         """
         mark = "" if re.search(r"[><=]", s) is None else re.search(r"[><=]", s).group()
         unit = s.rstrip()[-2:]
-        number = float(re.search(r"[0-9.\+\-]+", s).group())
+        number = float(re.search(r"[\+\-]*[0-9.]+", s).group())
 
         if unit in ["µM", "uM"]:
             unit = "nM"
@@ -150,8 +151,10 @@ def eval_sample(self, sample, rng):
         )
         sampled = result.get_completions()[0]
 
+        compare_fields_types = [type(x) for x in sample.compare_fields]
+        header_rows = [0, 1] if tuple in compare_fields_types else [0]
+
-        correct_answer = parse_table_multiindex(pd.read_csv(sample.answerfile_name, header=[0, 1]).astype(str))
-        correct_answer = correct_answer.sort_values(by=("Compound", ""))
+        correct_answer = parse_table_multiindex(pd.read_csv(sample.answerfile_name, header=header_rows).astype(str))
         correct_answer.to_csv("temp.csv", index=False)
         correct_str = open("temp.csv", 'r').read()
 
@@ -161,7 +164,7 @@ def eval_sample(self, sample, rng):
                 table = pd.read_csv(StringIO(code_content))
                 if pd.isna(table.iloc[0, 0]):
-                    table = pd.read_csv(StringIO(code_content), header=[0, 1])
+                    table = pd.read_csv(StringIO(code_content), header=header_rows)
 
             elif "json" in prompt:
                 code = re.search(code_pattern, sampled).group()
@@ -169,7 +172,7 @@ def eval_sample(self, sample, rng):
                 table = pd.DataFrame(json.loads(code_content))
             else:
                 table = pd.DataFrame()
-            table = parse_table_multiindex(table).sort_values(by=("Compound", ""))
+            table = parse_table_multiindex(table)
         except:
             record_match(
                 correct=False,
@@ -184,10 +187,17 @@ def eval_sample(self, sample, rng):
         table.to_csv(answerfile_out, index=False)
         picked_str = open(answerfile_out, 'r').read()
 
+        comparison_df = pd.merge(table.set_index(sample.index, drop=False),
+                                 correct_answer.set_index(sample.index, drop=False),
+                                 how="right", left_index=True, right_index=True)
+
         match_all = True
         for field in sample.compare_fields:
             if type(field) == tuple and len(field) > 1:
                 field = (field[0], fuzzy_normalize(field[1]))
+                field_sample, field_correct = (f"{field[0]}_x", field[1]), (f"{field[0]}_y", field[1])
+            else:
+                field_sample, field_correct = f"{field}_x", f"{field}_y"
             match_field = field in table.columns and field in correct_answer.columns
             match_all = match_all and match_field
             record_match(
@@ -208,7 +218,7 @@ def eval_sample(self, sample, rng):
                 jobtype="match_number"
             )
 
-            for sample_value, correct_value in zip(table[field], correct_answer[field]):
+            for sample_value, correct_value in zip(comparison_df[field_sample], comparison_df[field_correct]):
                 match_value = fuzzy_compare(str(sample_value), str(correct_value))
                 match_all = match_all and match_value
                 record_match(
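The heart of PATCH 14 is aligning the extracted table with the reference table on a user-defined index before comparing cells, so row order and missing rows no longer cause spurious mismatches. A toy illustration of the same right-merge pattern (column names invented for the example):

```python
import pandas as pd

# Stand-ins for the model's table (left) and the ground truth (right).
extracted = pd.DataFrame({"Compound": ["5a", "5b"], "IC50": ["2.0 nM", "8.1 nM"]})
reference = pd.DataFrame({"Compound": ["5a", "5b", "5c"], "IC50": ["2.0 nM", "8.0 nM", "150 nM"]})

# how="right" keeps every ground-truth row, so a missing extraction counts as a miss.
comparison = pd.merge(
    extracted.set_index("Compound", drop=False),
    reference.set_index("Compound", drop=False),
    how="right", left_index=True, right_index=True,
)
print(comparison[["IC50_x", "IC50_y"]])  # _x = extracted, _y = expected
```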
r"(\1|\2)", lines[0]) + lines_clr = [re.sub(r"\"[\s\S]*?\"", "", line) for line in lines] + max_commas = max([line_clr.count(",") for line_clr in lines_clr]) + unified_lines = [line + ("," * (max_commas - line_clr.count(","))) for line, line_clr in zip(lines, lines_clr)] + return "\n".join(unified_lines) + + def parse_table_multiindex(table: pd.DataFrame) -> pd.DataFrame: """ Parse a table with multiindex columns. @@ -39,7 +52,8 @@ def parse_table_multiindex(table: pd.DataFrame) -> pd.DataFrame: d = pd.DataFrame(df.pop(col).tolist()) d.columns = pd.MultiIndex.from_tuples([(col, fuzzy_normalize(key)) for key in d.columns]) dfs.append(d) - df.columns = pd.MultiIndex.from_tuples([(col, "") for col in df.columns]) + df.columns = pd.MultiIndex.from_tuples([eval(col.replace("|", ",")) if (col[0] == "(" and col[-1] == ")") else + (col, "") for col in df.columns]) df = pd.concat([df] + dfs, axis=1) if df.columns.nlevels > 1: df.columns = pd.MultiIndex.from_tuples([(col, fuzzy_normalize(subcol)) for col, subcol in df.columns]) @@ -80,7 +94,7 @@ def standardize_unit(s: str) -> str: mark = "" return f"{mark}{number:.1f} {unit}" - unit_str = ["nM", "uM", "µM", "mM", "%", " %"] + unit_str = ["nM", "uM", "µM", "mM", "M", "%", " %"] nan_str = ["n/a", "nan", "na", "nd", "not determined", "not tested"] a = a.strip() b = b.strip() @@ -91,7 +105,8 @@ def standardize_unit(s: str) -> str: elif a.lower() in nan_str and b.lower() in nan_str: return True else: - return (a.lower() in b.lower()) or (b.lower() in a.lower()) + import Levenshtein + return (a.lower() in b.lower()) or (b.lower() in a.lower()) or Levenshtein.distance(a.lower(), b.lower()) / (len(a) + len(b)) < 0.1 def fuzzy_normalize(s): @@ -100,12 +115,12 @@ def fuzzy_normalize(s): else: """ 标准化字符串 """ # 定义需要移除的单位和符号 - units = ["µM", "µg/mL", "nM"] + units = ["µM", "µg/mL", "nM", "M"] for unit in units: s = s.replace(unit, "") # 定义特定关键字 - keywords = ["IC50", "EC50", "TC50", "GI50", "Ki", "Kd"] + keywords = ["pIC50", "IC50", "EC50", "TC50", "GI50", "Ki", "Kd", "Kb", "pKb"] # 移除非字母数字的字符,除了空格 s = re.sub(r'[^\w\s]', '', s) @@ -162,9 +177,11 @@ def eval_sample(self, sample, rng): if "csv" in prompt: code = re.search(code_pattern, sampled).group() code_content = re.sub(code_pattern, r"\1", code) - table = pd.read_csv(StringIO(code_content)) + code_content_processed = parse_csv_text(code_content) + # table = pd.read_csv(StringIO(code_content_processed), header=header_rows) + table = pd.read_csv(StringIO(code_content_processed)) if pd.isna(table.iloc[0, 0]): - table = pd.read_csv(StringIO(code_content), header=header_rows) + table = pd.read_csv(StringIO(code_content_processed), header=header_rows) elif "json" in prompt: code = re.search(code_pattern, sampled).group() @@ -173,7 +190,24 @@ def eval_sample(self, sample, rng): else: table = pd.DataFrame() table = parse_table_multiindex(table) + + if sample.index not in table.columns: + table.columns = [sample.index] + list(table.columns)[1:] + answerfile_out = sample.answerfile_name.replace(".csv", "_output.csv") + table.to_csv(answerfile_out, index=False) + picked_str = open(answerfile_out, 'r').read() + + comparison_df = pd.merge(table.set_index(sample.index, drop=False), + correct_answer.set_index(sample.index, drop=False), + how="right", left_index=True, right_index=True) except: + print(Path(sample.file_name).stem) + code = re.search(code_pattern, sampled).group() + code_content = re.sub(code_pattern, r"\1", code) + code_content_processed = parse_csv_text(code_content) + print(code_content) + 
print(code_content_processed) + traceback.print_exc() record_match( correct=False, expected=correct_str, @@ -183,14 +217,6 @@ def eval_sample(self, sample, rng): ) return - answerfile_out = sample.answerfile_name.replace(".csv", "_output.csv") - table.to_csv(answerfile_out, index=False) - picked_str = open(answerfile_out, 'r').read() - - comparison_df = pd.merge(table.set_index(sample.index, drop=False), - correct_answer.set_index(sample.index, drop=False), - how="right", left_index=True, right_index=True) - match_all = True for field in sample.compare_fields: if type(field) == tuple and len(field) > 1: diff --git a/evals/registry/data/00_scipaper_affinity/samples.jsonl b/evals/registry/data/00_scipaper_affinity/samples.jsonl index 5f093873e3..ff4161ddd4 100644 --- a/evals/registry/data/00_scipaper_affinity/samples.jsonl +++ b/evals/registry/data/00_scipaper_affinity/samples.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:3f95fad1b8c6426a1acef3ae0190d4c81f81f5ebb0b8ed6fe3b39c683f77b7ee -size 5047 +oid sha256:cf9a87b0db43a8b4324950dc46cf51d18e9c817e541c6445e4aea266bb6b1ee9 +size 5281 From decf0f3febb28410af00a71999aef16a3ab3edcb Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Fri, 19 Jan 2024 08:49:22 +0800 Subject: [PATCH 16/22] fix match_field compare logic to edit-distance --- evals/completion_fns/uni_finder.py | 2 +- evals/elsuite/rag_table_extract.py | 31 +++++++++++++++++++++++------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py index 5ac24a5d81..68ca25f891 100644 --- a/evals/completion_fns/uni_finder.py +++ b/evals/completion_fns/uni_finder.py @@ -89,7 +89,7 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo "query": prompt, 'api_key': self.api_key } - response = requests.post(url, json=payload) + response = requests.post(url, json=payload, timeout=1200) try: answer = response.json()['answer'] except: diff --git a/evals/elsuite/rag_table_extract.py b/evals/elsuite/rag_table_extract.py index ae68b6ee77..c9b207f104 100644 --- a/evals/elsuite/rag_table_extract.py +++ b/evals/elsuite/rag_table_extract.py @@ -70,7 +70,7 @@ class FileSample(BaseModel): index: Union[str, Tuple] = ("Compound", "") -def fuzzy_compare(a: str, b: str) -> bool: +def fuzzy_compare(a: str, b: str) -> Union[bool, float]: """ Compare two strings with fuzzy matching. 
""" @@ -104,9 +104,11 @@ def standardize_unit(s: str) -> str: return a == b elif a.lower() in nan_str and b.lower() in nan_str: return True + elif (a.lower() in b.lower()) or (b.lower() in a.lower()): + return True else: import Levenshtein - return (a.lower() in b.lower()) or (b.lower() in a.lower()) or Levenshtein.distance(a.lower(), b.lower()) / (len(a) + len(b)) < 0.1 + return Levenshtein.distance(a.lower(), b.lower()) / (len(a) + len(b)) < 0.1 def fuzzy_normalize(s): @@ -196,10 +198,6 @@ def eval_sample(self, sample, rng): answerfile_out = sample.answerfile_name.replace(".csv", "_output.csv") table.to_csv(answerfile_out, index=False) picked_str = open(answerfile_out, 'r').read() - - comparison_df = pd.merge(table.set_index(sample.index, drop=False), - correct_answer.set_index(sample.index, drop=False), - how="right", left_index=True, right_index=True) except: print(Path(sample.file_name).stem) code = re.search(code_pattern, sampled).group() @@ -217,10 +215,29 @@ def eval_sample(self, sample, rng): ) return + renames = {} + for field in sample.compare_fields: + for i, sample_field in enumerate(table.columns): + field_query = field if type(field) != tuple else field[0] if field[1] == "" else field[1] + sample_field_query = sample_field if type(sample_field) != tuple else sample_field[0] if sample_field[1] == "" else sample_field[1] + if fuzzy_normalize(field_query) == "" or fuzzy_normalize(sample_field_query) == "": + continue + if fuzzy_compare(fuzzy_normalize(field_query), fuzzy_normalize(sample_field_query)) and \ + fuzzy_normalize(field_query).split()[-1] == fuzzy_normalize(sample_field_query).split()[-1]: + if sample_field not in renames and sample_field_query != field_query: + renames[sample_field_query] = field_query + break + if len(renames) > 0: + print("Find similar fields between answer and correct:", renames) + table.rename(columns=renames, inplace=True) + + comparison_df = pd.merge(table.set_index(sample.index, drop=False), + correct_answer.set_index(sample.index, drop=False), + how="right", left_index=True, right_index=True) + match_all = True for field in sample.compare_fields: if type(field) == tuple and len(field) > 1: - field = (field[0], fuzzy_normalize(field[1])) field_sample, field_correct = (f"{field[0]}_x", field[1]), (f"{field[0]}_y", field[1]) else: field_sample, field_correct = f"{field}_x", f"{field}_y" From 86f90c9792353ffc02f9aad279da53274759610c Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Fri, 19 Jan 2024 13:41:45 +0800 Subject: [PATCH 17/22] fixes on data and details for good scipaper_affinity performance --- evals/elsuite/rag_table_extract.py | 13 +++++++++---- .../data/00_scipaper_affinity/samples.jsonl | 4 ++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/evals/elsuite/rag_table_extract.py b/evals/elsuite/rag_table_extract.py index c9b207f104..19d58f7cd2 100644 --- a/evals/elsuite/rag_table_extract.py +++ b/evals/elsuite/rag_table_extract.py @@ -95,7 +95,7 @@ def standardize_unit(s: str) -> str: return f"{mark}{number:.1f} {unit}" unit_str = ["nM", "uM", "µM", "mM", "M", "%", " %"] - nan_str = ["n/a", "nan", "na", "nd", "not determined", "not tested"] + nan_str = ["n/a", "nan", "na", "n.a.", "nd", "not determined", "not tested", "inactive"] a = a.strip() b = b.strip() if (a[-2:] in unit_str or a[-1] in unit_str) and (b[-2:] in unit_str or b[-1] in unit_str): @@ -117,7 +117,7 @@ def fuzzy_normalize(s): else: """ 标准化字符串 """ # 定义需要移除的单位和符号 - units = ["µM", "µg/mL", "nM", "M"] + units = ["µM", "µg/mL", "nM"] 
From 86f90c9792353ffc02f9aad279da53274759610c Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Fri, 19 Jan 2024 13:41:45 +0800
Subject: [PATCH 17/22] fix data and evaluation details to improve scipaper_affinity performance

---
 evals/elsuite/rag_table_extract.py            | 13 +++++++++----
 .../data/00_scipaper_affinity/samples.jsonl   |  4 ++--
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/evals/elsuite/rag_table_extract.py b/evals/elsuite/rag_table_extract.py
index c9b207f104..19d58f7cd2 100644
--- a/evals/elsuite/rag_table_extract.py
+++ b/evals/elsuite/rag_table_extract.py
@@ -95,7 +95,7 @@ def standardize_unit(s: str) -> str:
         return f"{mark}{number:.1f} {unit}"

     unit_str = ["nM", "uM", "µM", "mM", "M", "%", " %"]
-    nan_str = ["n/a", "nan", "na", "nd", "not determined", "not tested"]
+    nan_str = ["n/a", "nan", "na", "n.a.", "nd", "not determined", "not tested", "inactive"]
     a = a.strip()
     b = b.strip()
     if (a[-2:] in unit_str or a[-1] in unit_str) and (b[-2:] in unit_str or b[-1] in unit_str):
@@ -117,7 +117,7 @@ def fuzzy_normalize(s):
     else:
         """ Normalize the string """
         # Units and symbols to strip
-        units = ["µM", "µg/mL", "nM", "M"]
+        units = ["µM", "µg/mL", "nM"]
         for unit in units:
             s = s.replace(unit, "")

@@ -125,7 +125,7 @@ def fuzzy_normalize(s):
         keywords = ["pIC50", "IC50", "EC50", "TC50", "GI50", "Ki", "Kd", "Kb", "pKb"]

         # Remove non-alphanumeric characters, except spaces
-        s = re.sub(r'[^\w\s]', '', s)
+        # s = re.sub(r'[^\w\s]', '', s)

         # Split the string into a list of words
         words = s.split()
@@ -215,6 +215,7 @@ def eval_sample(self, sample, rng):
             )
             return

+        # TODO: use similarity and bipartite matching to match fields
         renames = {}
         for field in sample.compare_fields:
             for i, sample_field in enumerate(table.columns):
@@ -224,13 +225,17 @@ def eval_sample(self, sample, rng):
                     continue
                 if fuzzy_compare(fuzzy_normalize(field_query), fuzzy_normalize(sample_field_query)) and \
                         fuzzy_normalize(field_query).split()[-1] == fuzzy_normalize(sample_field_query).split()[-1]:
-                    if sample_field not in renames and sample_field_query != field_query:
+                    if sample_field not in renames.keys() and field_query not in renames.values():
                         renames[sample_field_query] = field_query
                     break
+        renames = {key: value for key, value in renames.items() if key not in ["Compound", "Name", "SMILES"]}
         if len(renames) > 0:
             print("Found similar fields between answer and correct answer:", renames)
             table.rename(columns=renames, inplace=True)

+        print(table)
+        table[sample.index] = table[sample.index].astype(str)
+        correct_answer[sample.index] = correct_answer[sample.index].astype(str)
         comparison_df = pd.merge(table.set_index(sample.index, drop=False),
                                  correct_answer.set_index(sample.index, drop=False),
                                  how="right", left_index=True, right_index=True)
diff --git a/evals/registry/data/00_scipaper_affinity/samples.jsonl b/evals/registry/data/00_scipaper_affinity/samples.jsonl
index ff4161ddd4..8988938c97 100644
--- a/evals/registry/data/00_scipaper_affinity/samples.jsonl
+++ b/evals/registry/data/00_scipaper_affinity/samples.jsonl
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cf9a87b0db43a8b4324950dc46cf51d18e9c817e541c6445e4aea266bb6b1ee9
-size 5281
+oid sha256:62f7cbd7ee9d4b0b4f7fdaa48a5d3033bb11900578cc3b3516ec4f1b052b0dc9
+size 5391

From 9860058c047d75f6a9caf3dd73a9cc58e96cda3c Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Fri, 19 Jan 2024 15:16:37 +0800
Subject: [PATCH 18/22] update uni_finder api with pdf_parse_mode

---
 evals/completion_fns/uni_finder.py            | 19 ++++++++++++-------
 evals/registry/completion_fns/uni_finder.yaml | 18 ++++++++++++++++--
 .../data/00_scipaper_affinity/samples.jsonl   |  4 ++--
 3 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py
index 68ca25f891..461bfd5065 100644
--- a/evals/completion_fns/uni_finder.py
+++ b/evals/completion_fns/uni_finder.py
@@ -35,6 +35,7 @@ def __init__(
         api_key: Optional[str] = None,
         n_ctx: Optional[int] = None,
         cache_dir: Optional[str] = str(Path.home() / ".uni_finder/knowledge_base.json"),
+        pdf_parse_mode: Optional[str] = 'fast',  # or 'precise'; selects the PDF parsing mode
         extra_options: Optional[dict] = {},
         **kwargs
     ):
@@ -45,6 +46,7 @@ def __init__(
         self.n_ctx = n_ctx
         self.extra_options = extra_options
         self.cache_dir = cache_dir
+        self.pdf_parse_mode = pdf_parse_mode
         Path(self.cache_dir).parent.mkdir(parents=True, exist_ok=True)
         if not Path(self.cache_dir).exists():
             json.dump({}, open(self.cache_dir, "w"))
@@ -60,21 +62,24 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
         if "file_name" in kwargs:
             cache = json.load(open(self.cache_dir, 'r+'))

-            if kwargs["file_name"] not in cache:
+            if cache.get(kwargs["file_name"], {}).get(self.pdf_parse_mode, None) is None:
                 url = f"{self.api_base}/api/external/upload_pdf"
-                pdf_parse_mode = 'fast'  # or 'precise'; selects the PDF parsing mode
                 files = {'file': open(kwargs["file_name"], 'rb')}
                 data = {
-                    'pdf_parse_mode': pdf_parse_mode,
+                    'pdf_parse_mode': self.pdf_parse_mode,
                     'api_key': self.api_key
                 }
-                response = requests.post(url, data=data, files=files).json()
-                pdf_id = response['pdf_token']  # PDF id returned on successful upload; used to reference this PDF later
+                response = requests.post(url, data=data, files=files)
+                print(response.text)
+                pdf_id = response.json()['pdf_token']  # PDF id returned on successful upload; used to reference this PDF later

-                cache[kwargs["file_name"]] = pdf_id
+                if kwargs["file_name"] not in cache:
+                    cache[kwargs["file_name"]] = {self.pdf_parse_mode: pdf_id}
+                else:
+                    cache[kwargs["file_name"]][self.pdf_parse_mode] = pdf_id
                 json.dump(cache, open(self.cache_dir, "w"))
             else:
-                pdf_id = cache[kwargs["file_name"]]
+                pdf_id = cache[kwargs["file_name"]][self.pdf_parse_mode]
             print("############# pdf_id ##############", pdf_id)
             pdf_token.append(pdf_id)
diff --git a/evals/registry/completion_fns/uni_finder.yaml b/evals/registry/completion_fns/uni_finder.yaml
index b31c691445..ae2c7b778e 100644
--- a/evals/registry/completion_fns/uni_finder.yaml
+++ b/evals/registry/completion_fns/uni_finder.yaml
@@ -1,9 +1,23 @@
-uni_finder/gpt-3.5-turbo:
+uni_finder/fast/gpt-3.5-turbo:
   class: evals.completion_fns.uni_finder:UniFinderCompletionFn
   args:
+    pdf_parse_mode: fast
     model: gpt35

-uni_finder/gpt-4-all:
+uni_finder/precise/gpt-3.5-turbo:
   class: evals.completion_fns.uni_finder:UniFinderCompletionFn
   args:
+    pdf_parse_mode: precise
+    model: gpt35
+
+uni_finder/fast/gpt-4-all:
+  class: evals.completion_fns.uni_finder:UniFinderCompletionFn
+  args:
+    pdf_parse_mode: fast
     model: gpt4
+
+uni_finder/precise/gpt-4-all:
+  class: evals.completion_fns.uni_finder:UniFinderCompletionFn
+  args:
+    pdf_parse_mode: precise
+    model: gpt4
\ No newline at end of file
diff --git a/evals/registry/data/00_scipaper_affinity/samples.jsonl b/evals/registry/data/00_scipaper_affinity/samples.jsonl
index 8988938c97..ed8c716f73 100644
--- a/evals/registry/data/00_scipaper_affinity/samples.jsonl
+++ b/evals/registry/data/00_scipaper_affinity/samples.jsonl
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:62f7cbd7ee9d4b0b4f7fdaa48a5d3033bb11900578cc3b3516ec4f1b052b0dc9
-size 5391
+oid sha256:2e7193711a5f342aa26e16fe275e108eefc8fdc1ef6f4f6545dcb8f901132f2d
+size 4916
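As a side note, the cache layout this patch switches to can be sketched in
isolation: the knowledge-base JSON now maps file name -> {pdf_parse_mode ->
pdf_token}, so a 'fast' and a 'precise' parse of the same PDF are cached
independently. The helper names below are illustrative; only the path mirrors
the default cache_dir above.

    import json
    from pathlib import Path
    from typing import Optional

    cache_path = Path.home() / ".uni_finder/knowledge_base.json"
    cache = json.loads(cache_path.read_text()) if cache_path.exists() else {}

    def cached_pdf_id(file_name: str, mode: str) -> Optional[str]:
        # None means this file has never been uploaded with this parse mode.
        return cache.get(file_name, {}).get(mode)

    def remember_pdf_id(file_name: str, mode: str, pdf_id: str) -> None:
        cache.setdefault(file_name, {})[mode] = pdf_id
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        cache_path.write_text(json.dumps(cache))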
"content": prompt}] + payload = { "model": self.model, - "messages": [ - {"role": "system", "content": self.instructions}, - {"role": "user", "content": f"{kwargs['file_link']} {prompt}"} - ] + "token": self.api_key or os.environ['ZHISHU_API_KEY'], + "messages": messages } result = request_with_timeout(requests.post, url, json=payload, headers=headers) diff --git a/evals/registry/completion_fns/zhishu.yaml b/evals/registry/completion_fns/zhishu.yaml index 0a5d712b79..936c759f96 100644 --- a/evals/registry/completion_fns/zhishu.yaml +++ b/evals/registry/completion_fns/zhishu.yaml @@ -3,7 +3,7 @@ zhishu/gpt-3.5-turbo: args: model: gpt-3.5-turbo -zhishu/gpt-4.0-turbo: +zhishu/gpt-4-all: class: evals.completion_fns.zhishu:ZhishuCompletionFn args: - model: gpt-4.0-turbo + model: gpt-4-all From 3a4a643cb145dd440cded6af956d963aeb9eb9ef Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Tue, 23 Jan 2024 15:05:19 +0800 Subject: [PATCH 20/22] split test sets into general_chemistry and drug_discovery --- evals/registry/eval_sets/chemistry.yaml | 8 ++------ evals/registry/eval_sets/chemistry_drug.yaml | 8 ++++++++ 2 files changed, 10 insertions(+), 6 deletions(-) create mode 100644 evals/registry/eval_sets/chemistry_drug.yaml diff --git a/evals/registry/eval_sets/chemistry.yaml b/evals/registry/eval_sets/chemistry.yaml index 421901e504..ad417da139 100644 --- a/evals/registry/eval_sets/chemistry.yaml +++ b/evals/registry/eval_sets/chemistry.yaml @@ -1,11 +1,7 @@ chemistry: evals: - - scipaper_affinity - - scipaper_tag2mol - - scipaper_hasmol - - markush2mol - - scipaper_targets - abstract2title - research-question-extraction - balance-chemical-equation - - medmcqa \ No newline at end of file + - mmlu-college-chemistry + - mmlu-high-school-chemistry \ No newline at end of file diff --git a/evals/registry/eval_sets/chemistry_drug.yaml b/evals/registry/eval_sets/chemistry_drug.yaml new file mode 100644 index 0000000000..f152a92231 --- /dev/null +++ b/evals/registry/eval_sets/chemistry_drug.yaml @@ -0,0 +1,8 @@ +chemistry_drug: + evals: + - scipaper_affinity + - scipaper_tag2mol + - scipaper_hasmol + - markush2mol + - scipaper_targets + - medmcqa \ No newline at end of file From e6dece513618b10ff27297f804f3ae4ab6b89a40 Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Thu, 25 Jan 2024 12:04:59 +0800 Subject: [PATCH 21/22] fix Zhishu for mocked GPT-4 --- evals/completion_fns/uni_finder.py | 1 - evals/completion_fns/zhishu.py | 12 ++++++---- evals/elsuite/rag_table_extract.py | 28 ++++++++++++++++------- evals/registry/completion_fns/zhishu.yaml | 6 +++-- 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py index 461bfd5065..5a6fa04c9e 100644 --- a/evals/completion_fns/uni_finder.py +++ b/evals/completion_fns/uni_finder.py @@ -70,7 +70,6 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo 'api_key': self.api_key } response = requests.post(url, data=data, files=files) - print(response.text) pdf_id = response.json()['pdf_token'] # 获得pdf的id,表示上传成功,后续可以使用这个id来指定pdf if kwargs["file_name"] not in cache: diff --git a/evals/completion_fns/zhishu.py b/evals/completion_fns/zhishu.py index 84e2dab60a..5677c67f54 100644 --- a/evals/completion_fns/zhishu.py +++ b/evals/completion_fns/zhishu.py @@ -59,6 +59,7 @@ def __init__( api_base: Optional[str] = None, api_key: Optional[str] = None, n_ctx: Optional[int] = None, + all_tools: Optional[bool] = False, 
         extra_options: Optional[dict] = {},
         **kwargs,
     ):
@@ -67,6 +68,7 @@ def __init__(
         self.api_base = api_base
         self.api_key = api_key
         self.n_ctx = n_ctx
+        self.all_tools = all_tools
         self.extra_options = extra_options

     def __call__(
@@ -87,19 +89,21 @@ def __call__(
             "content-type": "application/json"
         }

-        messages = [
-            {"role": "system", "content": self.instructions},
+        basic_message = [{"role": "system", "content": self.instructions}] if self.all_tools else []
+
+        messages = basic_message + [
             {"role": "user", "content": f"{kwargs['file_link']} {prompt}"}
         ] if "file_link" in kwargs else prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]

         payload = {
             "model": self.model,
-            "token": self.api_key or os.environ['ZHISHU_API_KEY'],
             "messages": messages
         }

-        result = request_with_timeout(requests.post, url, json=payload, headers=headers)
+        # result = request_with_timeout(requests.post, url, json=payload, headers=headers)
+        result = requests.post(url, json=payload, headers=headers)
         result = ZhishuCompletionResult(raw_data=result.json(), prompt=prompt)
+        print(result.get_completions()[0].replace("\\n", "\n"))
         record_sampling(prompt=result.prompt, sampled=result.get_completions())
         return result
diff --git a/evals/elsuite/rag_table_extract.py b/evals/elsuite/rag_table_extract.py
index 19d58f7cd2..281977ef61 100644
--- a/evals/elsuite/rag_table_extract.py
+++ b/evals/elsuite/rag_table_extract.py
@@ -1,3 +1,4 @@
+import os
 import traceback
 from io import StringIO
 import json
@@ -8,6 +9,7 @@
 import pandas as pd
 from pydantic import BaseModel
+import uuid

 import evals
 import evals.metrics
@@ -15,9 +17,11 @@
 from evals.elsuite.rag_match import get_rag_dataset
 from evals.record import RecorderBase, record_match

+
 code_pattern = r"```[\s\S]*?\n([\s\S]+?)\n```"
 json_pattern = r"```json[\s\S]*?\n([\s\S]+?)\n```"
 csv_pattern = r"```csv[\s\S]*?\n([\s\S]+?)\n```"
+outlink_pattern = r"\[Download[a-zA-Z0-9 ]+?\]\((https://[a-zA-Z0-9_. /]+?)\)"
/]+?)\)" def parse_csv_text(csvtext: str) -> str: @@ -155,10 +159,9 @@ def __init__( def eval_sample(self, sample, rng): assert isinstance(sample, FileSample) - prompt = ( + prompt = \ self.instructions - + f"\nThe fields should at least contain {sample.compare_fields}" - ) + # + f"\nThe fields should at least contain {sample.compare_fields}" result = self.completion_fn( prompt=prompt, temperature=0.0, @@ -176,9 +179,18 @@ def eval_sample(self, sample, rng): correct_str = open("temp.csv", 'r').read() try: - if "csv" in prompt: - code = re.search(code_pattern, sampled).group() - code_content = re.sub(code_pattern, r"\1", code) + if re.search(outlink_pattern, sampled) is not None: + code = re.search(outlink_pattern, sampled).group() + link = re.sub(outlink_pattern, r"\1", code) + + fname = f"/tmp/LLMEvals_{uuid.uuid4()}.csv" + os.system(f"wget {link} -O {fname}") + table = pd.read_csv(fname) + if pd.isna(table.iloc[0, 0]): + table = pd.read_csv(fname, header=header_rows) + elif "csv" in prompt: + code = re.search(csv_pattern, sampled).group() + code_content = re.sub(csv_pattern, r"\1", code) code_content_processed = parse_csv_text(code_content) # table = pd.read_csv(StringIO(code_content_processed), header=header_rows) table = pd.read_csv(StringIO(code_content_processed)) @@ -186,8 +198,8 @@ def eval_sample(self, sample, rng): table = pd.read_csv(StringIO(code_content_processed), header=header_rows) elif "json" in prompt: - code = re.search(code_pattern, sampled).group() - code_content = re.sub(code_pattern, r"\1", code).replace("\"", "") + code = re.search(json_pattern, sampled).group() + code_content = re.sub(json_pattern, r"\1", code).replace("\"", "") table = pd.DataFrame(json.loads(code_content)) else: table = pd.DataFrame() diff --git a/evals/registry/completion_fns/zhishu.yaml b/evals/registry/completion_fns/zhishu.yaml index 936c759f96..13c281f84a 100644 --- a/evals/registry/completion_fns/zhishu.yaml +++ b/evals/registry/completion_fns/zhishu.yaml @@ -1,9 +1,11 @@ -zhishu/gpt-3.5-turbo: +zhishu/gpt-4: class: evals.completion_fns.zhishu:ZhishuCompletionFn args: - model: gpt-3.5-turbo + model: gpt-4-all + all_tools: False zhishu/gpt-4-all: class: evals.completion_fns.zhishu:ZhishuCompletionFn args: model: gpt-4-all + all_tools: True From 21cef0c095eb101573c06fa9b4b391faf8980e0f Mon Sep 17 00:00:00 2001 From: TablewareBox <1700011741@pku.edu.cn> Date: Fri, 26 Jan 2024 11:56:48 +0800 Subject: [PATCH 22/22] move --mlops option into llmreport entrypoint --- evals/cli/llmreport.py | 96 ++++++++++++++++++++++++++++++++++++++++++ evals/cli/oaieval.py | 43 ------------------- pyproject.toml | 1 + 3 files changed, 97 insertions(+), 43 deletions(-) create mode 100644 evals/cli/llmreport.py diff --git a/evals/cli/llmreport.py b/evals/cli/llmreport.py new file mode 100644 index 0000000000..15dd95471a --- /dev/null +++ b/evals/cli/llmreport.py @@ -0,0 +1,96 @@ +import argparse +import json +import pickle +import re +import glob +from io import StringIO +from pathlib import Path + +import pandas as pd +import matplotlib.pyplot as plt + + +def main() -> None: + parser = argparse.ArgumentParser(description="Report evals results") + parser.add_argument("run_id", type=str, nargs="+", help="Eval Run id") + parser.add_argument("--mlops", type=str, default=None) + parser.add_argument("--name", type=str, default="LLM_Eval") + + args = parser.parse_args() + + logfiles = [] + for run_id in args.run_id: + logfiles += glob.glob(f"/tmp/evallogs/{run_id}*/**", recursive=True) + logfiles = sorted([f for f in 
From 21cef0c095eb101573c06fa9b4b391faf8980e0f Mon Sep 17 00:00:00 2001
From: TablewareBox <1700011741@pku.edu.cn>
Date: Fri, 26 Jan 2024 11:56:48 +0800
Subject: [PATCH 22/22] move --mlops option into llmreport entrypoint

---
 evals/cli/llmreport.py | 96 ++++++++++++++++++++++++++++++++++++++++++
 evals/cli/oaieval.py   | 43 -------------------
 pyproject.toml         |  1 +
 3 files changed, 97 insertions(+), 43 deletions(-)
 create mode 100644 evals/cli/llmreport.py

diff --git a/evals/cli/llmreport.py b/evals/cli/llmreport.py
new file mode 100644
index 0000000000..15dd95471a
--- /dev/null
+++ b/evals/cli/llmreport.py
@@ -0,0 +1,96 @@
+import argparse
+import json
+import pickle
+import re
+import glob
+from io import StringIO
+from pathlib import Path
+
+import pandas as pd
+import matplotlib.pyplot as plt
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Report evals results")
+    parser.add_argument("run_id", type=str, nargs="+", help="Eval Run id")
+    parser.add_argument("--mlops", type=str, default=None)
+    parser.add_argument("--name", type=str, default="LLM_Eval")
+
+    args = parser.parse_args()
+
+    logfiles = []
+    for run_id in args.run_id:
+        logfiles += glob.glob(f"/tmp/evallogs/{run_id}*/**", recursive=True)
+    logfiles = sorted([f for f in logfiles if Path(f).suffix == ".jsonl"])
+    logger_data = {}
+    table_collection = []
+    qa_collection = []
+
+    for logfile in logfiles:
+        with open(logfile, "r") as f:
+            events_df = pd.read_json(f, lines=True)
+        if "final_report" not in events_df.columns:
+            continue
+        final_report = events_df["final_report"].dropna().iloc[0]
+
+        print(events_df)
+        run_config = events_df.loc[0, "spec"]
+        evalname = run_config["base_eval"]
+        model = run_config["completion_fns"][0].replace("/", ".")
+        matches_df = events_df[events_df["type"] == "match"].reset_index(drop=True)
+        matches_df = matches_df.join(pd.json_normalize(matches_df.data))
+
+        qa_collection.append({"eval": evalname, "model": model, **final_report})
+
+        if "file_name" in matches_df.columns:
+            matches_df["doi"] = [re.sub("__([0-9]+)__", r"(\1)", Path(f).stem).replace("_", "/") for f in matches_df["file_name"]]
+
+        # TODO: compare on different completion_functions
+        if "jobtype" in matches_df.columns:
+            # Table extract tasks
+            accuracy_by_type_and_file = matches_df.groupby(["jobtype", "doi"])['correct'].mean().reset_index()
+            accuracy_by_type_and_file["model"] = model
+            table_collection.append(accuracy_by_type_and_file)
+
+            accuracy_by_type = matches_df.groupby(["jobtype"])['correct'].mean().to_dict()
+            print(accuracy_by_type_and_file)
+
+            logger_data = {**logger_data, **{f"Accuracy_{key}/model:{model}": value for key, value in accuracy_by_type.items()}}
+
+            for doi, df in matches_df.groupby("doi"):
+                print(df)
+                logger_data[f"{doi.replace('/', '_')}/model:{model},context:match"] = df[df["jobtype"] != "match_all"][["correct", "expected", "picked", "jobtype"]]
+                match_all_data = df[df["jobtype"] == "match_all"].iloc[0, :]
+                logger_data[f"{doi.replace('/', '_')}/context:truth"] = pd.read_csv(StringIO(match_all_data["expected"]), header=[0, 1])
+                logger_data[f"{doi.replace('/', '_')}/model:{model},context:extract"] = pd.read_csv(StringIO(match_all_data["picked"]), header=[0, 1]) \
+                    if df["jobtype"].iloc[0] != "match_all" else match_all_data["picked"]
+        else:
+            # Regular tasks
+            pass
+
+    if len(table_collection) > 0:
+        accuracy_by_model_type_and_file = pd.concat(table_collection)
+        metrics_by_eval = pd.DataFrame(qa_collection)
+        accuracies = metrics_by_eval[metrics_by_eval["accuracy"] >= 0]
+        scores = metrics_by_eval[metrics_by_eval["score"] >= 0]
+
+        if args.mlops:
+            import plotly.express as px
+            logger_data["TableExtraction"] = px.box(accuracy_by_model_type_and_file,
+                                                    x="jobtype", y="correct", color="model",
+                                                    title="Accuracy by jobtype and model")
+            logger_data["QA_accuracy"] = px.bar(accuracies, x="eval", y="accuracy", color="model",
+                                                title="Accuracy by eval and model")
+            logger_data["QA_score"] = px.bar(scores, x="eval", y="score", color="model",
+                                             title="Score by eval and model")
+
+    if args.mlops:
+        config_logger = json.load(open(args.mlops, 'r'))
+        if "name" not in config_logger.keys():
+            config_logger["name"] = args.name
+        if "dp_mlops" in config_logger:
+            from evals.reporters.DPTracking import DPTrackingReporter
+            DPTrackingReporter.report_run(config_logger, {}, logger_data, step=0)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py
index 0841672fb4..b8adfcba5a 100644
--- a/evals/cli/oaieval.py
+++ b/evals/cli/oaieval.py
@@ -53,7 +53,6 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--log_to_file", type=str, default=None, help="Log to a file instead of stdout"
     )
-    parser.add_argument("--mlops", type=str, default=None)
     parser.add_argument(
         "--registry_path",
         type=str,
@@ -112,7 +111,6 @@ class OaiEvalArguments(argparse.Namespace):
     user: str
     record_path: Optional[str]
     log_to_file: Optional[str]
-    mlops: Optional[str]
     registry_path: list[str]
     debug: bool
     local_run: bool
@@ -237,47 +235,6 @@ def to_number(x: str) -> Union[int, float, str]:
     for key, value in result.items():
         logger.info(f"{key}: {value}")

-    if args.mlops:
-        import pandas as pd
-        import plotly.express as px
-
-        recorder.flush_events()
-        with open(record_path, "r") as f:
-            events_df = pd.read_json(f, lines=True)
-
-        print(events_df)
-        run_config = events_df.loc[0, "spec"]
-        matches_df = events_df[events_df["type"] == "match"].reset_index(drop=True)
-        matches_df = matches_df.join(pd.json_normalize(matches_df.data))
-
-        matches_df["doi"] = [re.sub("__([0-9]+)__", "(\1)", Path(f).stem).replace("_", "/") for f in matches_df["file_name"]]
-
-        # TODO: compare on different completion_functions
-        accuracy_by_type_and_file = matches_df.groupby(["jobtype", "doi"])['correct'].mean().reset_index()
-        accuracy_by_type = matches_df.groupby(["jobtype"])['correct'].mean().to_dict()
-
-        print(accuracy_by_type_and_file)
-
-        logger_data = {
-            **accuracy_by_type,
-            "Accuracy": px.box(accuracy_by_type_and_file, x="jobtype", y="correct", color="jobtype", title="Accuracy by jobtype and model"),
-        }
-
-        for doi, df in matches_df.groupby("doi"):
-            logger_data[f"{doi.replace('/', '_')}/context:match"] = df[df["jobtype"] != "match_all"][["correct", "expected", "picked", "jobtype"]]
-            match_all_data = df[df["jobtype"] == "match_all"].iloc[0, :]
-            logger_data[f"{doi.replace('/', '_')}/context:truth"] = pd.read_csv(StringIO(match_all_data["expected"]), header=[0, 1])
-            logger_data[f"{doi.replace('/', '_')}/context:extract"] = pd.read_csv(StringIO(match_all_data["picked"]), header=[0, 1]) \
-                if df["jobtype"].iloc[0] != "match_all" else match_all_data["picked"]
-        pickle.dump(logger_data, open("logger_data.pkl", "wb"))
-
-        config_logger = json.load(open(args.mlops, 'r'))
-        if "name" not in config_logger.keys():
-            config_logger["name"] = f"{run_spec.run_id}_{args.completion_fn}_{args.eval}"
-        if "dp_mlops" in config_logger:
-            from evals.reporters.DPTracking import DPTrackingReporter
-            DPTrackingReporter.report_run(config_logger, run_config, logger_data, step=0)
-
     return run_spec.run_id

diff --git a/pyproject.toml b/pyproject.toml
index b6eff11e67..05a5be64fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ formatters = [
 [project.scripts]
 oaieval = "evals.cli.oaieval:main"
 oaievalset = "evals.cli.oaievalset:main"
+llmreport = "evals.cli.llmreport:main"

 [tool.setuptools]
 packages = ["evals"]
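With the pyproject entry above, the reporter is installed as a console script, so
a typical session is an oaieval run followed by `llmreport <run_id> --mlops
mlops_config.json`. The core of what llmreport computes per event log can be
sketched in a few lines; the log path below is a placeholder (the entrypoint
itself globs /tmp/evallogs/<run_id>* for .jsonl files), and the snippet is
illustrative rather than part of the patch:

    import pandas as pd

    events_df = pd.read_json("/tmp/evallogs/<run_id>.jsonl", lines=True)
    matches_df = events_df[events_df["type"] == "match"].reset_index(drop=True)
    matches_df = matches_df.join(pd.json_normalize(matches_df.data))

    # Mean correctness per jobtype, as collected into table_collection above.
    print(matches_df.groupby("jobtype")["correct"].mean())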