-
Notifications
You must be signed in to change notification settings - Fork 3k
add mteb evaluation #8538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
add mteb evaluation #8538
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
107 changes: 107 additions & 0 deletions
107
pipelines/examples/contrastive_training/evaluation/mteb/eval_mteb.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import argparse | ||
| import logging | ||
|
|
||
| from mteb import MTEB | ||
| from mteb_models import EncodeModel | ||
|
|
||
| from paddlenlp.transformers import AutoModel, AutoTokenizer | ||
|
|
||
|
|
||
| def get_model(peft_model_name, base_model_name): | ||
| if peft_model_name is not None: | ||
| raise NotImplementedError("PEFT model is not supported yet") | ||
| else: | ||
| base_model = AutoModel.from_pretrained(base_model_name) | ||
| return base_model | ||
|
|
||
|
|
||
| def get_args(): | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--base_model_name_or_path", default="bge-large-en-v1.5", type=str) | ||
| parser.add_argument("--peft_model_name_or_path", default=None, type=str) | ||
| parser.add_argument("--output_folder", default="tmp", type=str) | ||
|
|
||
| parser.add_argument("--task_name", default="SciFact", type=str) | ||
| parser.add_argument( | ||
| "--task_split", | ||
| default="test", | ||
| help='Note that some datasets do not have "test", they only have "dev"', | ||
| type=str, | ||
| ) | ||
|
|
||
| parser.add_argument("--query_instruction", default=None, help="add prefix instruction before query", type=str) | ||
| parser.add_argument( | ||
| "--document_instruction", default=None, help="add prefix instruction before document", type=str | ||
| ) | ||
|
|
||
| parser.add_argument("--pooling_method", default="last", help="choose in [mean, last, cls]", type=str) | ||
| parser.add_argument("--max_seq_length", default=512, type=int) | ||
| parser.add_argument("--eval_batch_size", default=1, type=int) | ||
|
|
||
| parser.add_argument("--pad_token", default="unk_token", help="unk_token, eos_token or pad_token", type=str) | ||
| parser.add_argument("--padding_side", default="left", help="right or left", type=str) | ||
| parser.add_argument("--add_bos_token", default=0, help="1 means add token", type=int) | ||
| parser.add_argument("--add_eos_token", default=1, help="1 means add token", type=int) | ||
|
|
||
| return parser.parse_args() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| args = get_args() | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| logging.basicConfig(level=logging.INFO) | ||
| logger.info("Args: {}".format(args)) | ||
|
|
||
| model = get_model(args.peft_model_name_or_path, args.base_model_name_or_path) | ||
|
|
||
| tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path) | ||
| assert hasattr(tokenizer, args.pad_token), f"Tokenizer does not have {args.pad_token} token" | ||
| token_dict = {"unk_token": tokenizer.unk_token, "eos_token": tokenizer.eos_token, "pad_token": tokenizer.pad_token} | ||
| tokenizer.pad_token = token_dict[args.pad_token] | ||
|
|
||
| assert args.padding_side in [ | ||
| "right", | ||
| "left", | ||
| ], f"padding_side should be either 'right' or 'left', but got {args.padding_side}" | ||
| assert not ( | ||
| args.padding_side == "left" and args.pooling_method == "cls" | ||
| ), "Padding 'left' is not supported for pooling method 'cls'" | ||
| tokenizer.padding_side = args.padding_side | ||
|
|
||
| assert args.add_bos_token in [0, 1], f"add_bos_token should be either 0 or 1, but got {args.add_bos_token}" | ||
| assert args.add_eos_token in [0, 1], f"add_eos_token should be either 0 or 1, but got {args.add_eos_token}" | ||
| tokenizer.add_bos_token = bool(args.add_bos_token) | ||
| tokenizer.add_eos_token = bool(args.add_eos_token) | ||
|
|
||
| encode_model = EncodeModel( | ||
| model=model, | ||
| tokenizer=tokenizer, | ||
| pooling_method=args.pooling_method, | ||
| query_instruction=args.query_instruction, | ||
| document_instruction=args.document_instruction, | ||
| eval_batch_size=args.eval_batch_size, | ||
| max_seq_length=args.max_seq_length, | ||
| ) | ||
|
|
||
| logger.info("Ready to eval") | ||
| evaluation = MTEB(tasks=[args.task_name]) | ||
| evaluation.run( | ||
| encode_model, | ||
| output_folder=f"{args.output_folder}/{args.task_name}/{args.pooling_method}", | ||
| eval_splits=[args.task_split], | ||
| ) |
127 changes: 127 additions & 0 deletions
127
pipelines/examples/contrastive_training/evaluation/mteb/mteb_models.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Dict, List, Union | ||
|
|
||
| import numpy as np | ||
| import paddle | ||
| from tqdm import tqdm | ||
|
|
||
|
|
||
| class EncodeModel: | ||
| def __init__( | ||
| self, | ||
| model, | ||
| tokenizer, | ||
| pooling_method: str = "last", | ||
| query_instruction: str = None, | ||
| document_instruction: str = None, | ||
| eval_batch_size: int = 64, | ||
| max_seq_length: int = 512, | ||
| ): | ||
| self.model = model | ||
| self.tokenizer = tokenizer | ||
| self.pooling_method = pooling_method | ||
| self.query_instruction = query_instruction | ||
| self.document_instruction = document_instruction | ||
| self.eval_batch_size = eval_batch_size | ||
| self.max_seq_length = max_seq_length | ||
|
|
||
| if paddle.device.is_compiled_with_cuda(): | ||
| self.device = paddle.device.set_device("gpu") | ||
| else: | ||
| self.device = paddle.device.set_device("cpu") | ||
| self.model = self.model.to(self.device) | ||
|
|
||
| num_gpus = paddle.device.cuda.device_count() | ||
| if num_gpus > 1: | ||
| raise NotImplementedError("Multi-GPU is not supported yet.") | ||
|
|
||
| def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray: | ||
| """ | ||
| This function will be used to encode queries for retrieval task | ||
| if there is a instruction for queries, we will add it to the query text | ||
| """ | ||
| if self.query_instruction is not None: | ||
| input_texts = [f"{self.query_instruction}{query}" for query in queries] | ||
| else: | ||
| input_texts = queries | ||
| return self.encode(input_texts) | ||
|
|
||
| def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray: | ||
| """ | ||
| This function will be used to encode corpus for retrieval task | ||
| if there is a instruction for docs, we will add it to the doc text | ||
| """ | ||
| if isinstance(corpus[0], dict): | ||
| if self.document_instruction is not None: | ||
| input_texts = [ | ||
| "{}{} {}".format(self.document_instruction, doc.get("title", ""), doc["text"]).strip() | ||
| for doc in corpus | ||
| ] | ||
| else: | ||
| input_texts = ["{} {}".format(doc.get("title", ""), doc["text"]).strip() for doc in corpus] | ||
| else: | ||
| if self.document_instruction is not None: | ||
| input_texts = [f"{self.document_instruction}{doc}" for doc in corpus] | ||
| else: | ||
| input_texts = corpus | ||
| return self.encode(input_texts) | ||
|
|
||
| @paddle.no_grad() | ||
| def encode(self, sentences: List[str], **kwargs) -> np.ndarray: | ||
| self.model.eval() | ||
| all_embeddings = [] | ||
| for start_index in tqdm(range(0, len(sentences), self.eval_batch_size), desc="Batches"): | ||
| sentences_batch = sentences[start_index : start_index + self.eval_batch_size] | ||
|
|
||
| inputs = self.tokenizer( | ||
| sentences_batch, | ||
| padding=True, | ||
| truncation=True, | ||
| return_tensors="pd", | ||
| max_length=self.max_seq_length, | ||
| return_attention_mask=True, | ||
| ) | ||
| outputs = self.model( | ||
| input_ids=inputs.input_ids, | ||
| attention_mask=inputs.attention_mask, | ||
| return_dict=True, | ||
| output_hidden_states=True, | ||
| ) | ||
| last_hidden_state = outputs.hidden_states[-1] | ||
|
|
||
| if self.pooling_method == "last": | ||
| if self.tokenizer.padding_side == "right": | ||
| sequence_lengths = inputs.attention_mask.sum(axis=1) | ||
| last_token_indices = sequence_lengths - 1 | ||
| embeddings = last_hidden_state[paddle.arange(last_hidden_state.shape[0]), last_token_indices] | ||
| elif self.tokenizer.padding_side == "left": | ||
| embeddings = last_hidden_state[:, -1] | ||
| else: | ||
| raise NotImplementedError(f"Padding side {self.tokenizer.padding_side} not supported.") | ||
| elif self.pooling_method == "cls": | ||
| embeddings = last_hidden_state[:, 1] | ||
| elif self.pooling_method == "mean": | ||
| s = paddle.sum(last_hidden_state * inputs.attention_mask.unsqueeze(-1), axis=1) | ||
| d = inputs.attention_mask.sum(axis=1, keepdim=True) | ||
| embeddings = s / d | ||
| else: | ||
| raise NotImplementedError(f"Pooling method {self.pooling_method} not supported.") | ||
|
|
||
| embeddings = paddle.nn.functional.normalize(embeddings, p=2, axis=-1) | ||
|
|
||
| all_embeddings.append(embeddings.cpu().numpy().astype("float32")) | ||
|
|
||
| return np.concatenate(all_embeddings, axis=0) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| paddlenlp>2.6.1 | ||
| datasets | ||
| torch==2.0.1 | ||
| mteb[beir] | ||
| mteb | ||
| beir | ||
| typer==0.9.0 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参数的解释说明写到执行命令的后面,跟上面的写法保持一致。
