diff --git a/.github/workflows/sgl.yml b/.github/workflows/sgl.yml
index 35bd87107d7..917b7a77909 100644
--- a/.github/workflows/sgl.yml
+++ b/.github/workflows/sgl.yml
@@ -86,3 +86,7 @@ jobs:
         run: |
           cd tests/workers/rollout
           pytest -s test_sglang_async_rollout_sf_tools.py
+      - name: Test the latest SGLang Rollout async with search tool
+        run: |
+          cd tests/workers/rollout
+          pytest -s test_sglang_async_rollout_search_tools.py
\ No newline at end of file
diff --git a/docs/sglang_multiturn/multiturn.rst b/docs/sglang_multiturn/multiturn.rst
index 47891e9b3f1..bbb3e9bbc5d 100644
--- a/docs/sglang_multiturn/multiturn.rst
+++ b/docs/sglang_multiturn/multiturn.rst
@@ -1,8 +1,8 @@
 Multi-turn Rollout Support
-=========================
+==========================

 Basic Configuration
-~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~

 To enable multi-turn rollout, make sure to configure the following fields in your rollout configuration:
@@ -16,7 +16,7 @@ To enable multi-turn rollout, make sure to configure the following fields in you
 This configuration activates the sglang_async engine for multi-turn interaction during rollout.

 Custom Tool Configuration
-~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~

 For custom environment interaction tools, you can implement your own tools based on ``verl.tools.base_tool.BaseTool``. Then, specify your tool configurations in a YAML file:
@@ -41,7 +41,7 @@ Finally, set the ``tools_config_file`` in your rollout config:
 This allows integration of customized tool behaviors during actor rollout steps.

 GSM8K Multi-turn Training Performance
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 See the training performance of multi-turn rollout on the GSM8K task HERE_.
@@ -50,3 +50,11 @@ See the training performance of multi-turn rollout on the GSM8K task HERE_.
 .. _GSM8KTool_example_configuration: https://github.com/volcengine/verl/blob/main/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml
 .. _gsm8k_tool.py: https://github.com/volcengine/verl/blob/main/verl/tools/gsm8k_tool.py
+
+Search Tool Integration
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. toctree::
+   :maxdepth: 1
+
+   search_tool_example
\ No newline at end of file
diff --git a/docs/sglang_multiturn/search_tool_example.rst b/docs/sglang_multiturn/search_tool_example.rst
new file mode 100644
index 00000000000..4fac6ef3c1e
--- /dev/null
+++ b/docs/sglang_multiturn/search_tool_example.rst
@@ -0,0 +1,261 @@
=======================
Search Tool Integration
=======================

Introduction
------------

- We have added search tool calling to multi-turn RL, enabling the model to issue retrieval requests during actor rollout and use the retrieved results directly for training. **We support using a local dense retriever as the retrieval tool, as well as integrating with your own local retrieval engine.**

Quick Reproduction
------------------

Create a New Docker Container
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: bash

   docker run \
       -it \
       --shm-size 32g \
       --gpus all \
       -v {Huggingface-Cache-Path}:/root/.cache \
       --ipc=host \
       --network=host \
       --privileged \
       --name sglang_{your-name} \
       lmsysorg/sglang:dev \
       /bin/zsh

If you need to restart after exiting the container:

.. code:: bash

   docker start -i sglang_{your-name}

Update Python and Configure the Virtual Environment using uv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code:: bash

   apt update
   apt install -y python3.10 python3.10-venv

   # Create a virtual environment
   python3 -m venv ~/.python/verl-multiturn-rollout

   # Activate the virtual environment
   source ~/.python/verl-multiturn-rollout/bin/activate

   # Install uv
   python3 -m pip install uv

Install verl Upstream
~~~~~~~~~~~~~~~~~~~~~

.. code:: bash

   cd ~
   git clone https://github.com/volcengine/verl.git
   cd verl

   # Install verl
   python3 -m uv pip install .
   python3 -m uv pip install -r ./requirements_sglang.txt

   # Manually install flash-attn
   python3 -m uv pip install wheel
   python3 -m uv pip install packaging
   python3 -m uv pip install flash-attn --no-build-isolation --no-deps

Set Up a Local Retrieval Engine
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you are using your own local retrieval service, you can skip this step. We chose the local dense retriever provided in the search-R1 example; detailed instructions are in the `searchR1 docs `__. In brief:

- The GPU version offers higher accuracy and speed; each GPU uses about 5–7 GB of memory.
- The CPU version can be used for simple testing but has lower retrieval precision, which will degrade training performance. See the `retriever documentation `__ in search-R1 for details.
- We recommend using Conda to install faiss-gpu=1.8.0; installing it in a venv may cause errors.

**Note**: To start both the training process and the local retrieval service, we launch two separate Python environments. The training uses uv in the verl-multiturn-rollout environment, while the retriever uses conda to install ``faiss-gpu``.

.. code:: bash

   # Download the Miniconda installer script
   wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh

   # Install to $HOME/miniconda3 in batch mode
   bash ~/miniconda.sh -b -p $HOME/miniconda3

   # Activate conda (only in the current shell)
   eval "$($HOME/miniconda3/bin/conda shell.bash hook)"

   # (Optional) Add conda to your default shell startup
   conda init

   # Reload shell config
   source ~/.bashrc

   # Create and activate the retriever environment with Python 3.10
   conda create -n retriever python=3.10 -y
   conda activate retriever

   # Install PyTorch (with GPU support) and related libraries
   conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y

   # Install other Python packages
   pip install transformers datasets pyserini huggingface_hub

   # Install the GPU version of faiss
   conda install faiss-gpu=1.8.0 -c pytorch -c nvidia -y

   # Install the API service framework
   pip install uvicorn fastapi

Download the Index and Corpus
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The local retrieval files are large, so prepare sufficient disk space: the download is about 60–70 GB, and the uncompressed data takes about 132 GB:

.. code:: bash

   conda activate retriever

   save_path=/the/path/to/save
   python examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py --save_path $save_path
   cat $save_path/part_* > $save_path/e5_Flat.index
   gzip -d $save_path/wiki-18.jsonl.gz

Start the Local flat e5 Retrieval Server
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1. The first startup will download models and load the index.
2. Apart from the download, startup takes about 1–2 minutes.
3. After startup, each GPU uses about 5–7 GB of memory, leaving the rest for multi-turn RL training.

.. code:: bash

   conda activate retriever

   index_file=$save_path/e5_Flat.index
   corpus_file=$save_path/wiki-18.jsonl
   retriever_name=e5
   retriever_path=intfloat/e5-base-v2

   python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \
       --index_path $index_file \
       --corpus_path $corpus_file \
       --topk 3 \
       --retriever_name $retriever_name \
       --retriever_model $retriever_path \
       --faiss_gpu
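Once the server is up, you can smoke-test the endpoint from another shell. This is a minimal sanity check, assuming the default port 8000 and the request format described under Custom Search Configuration below:

.. code:: bash

   curl -s http://127.0.0.1:8000/retrieve \
       -H "Content-Type: application/json" \
       -d '{"queries": ["What is Python?"], "topk": 3, "return_scores": true}'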
Set Up WANDB_API_KEY
~~~~~~~~~~~~~~~~~~~~

.. code:: bash

   export WANDB_API_KEY={YOUR_WANDB_API_KEY}

   # Define a timestamp function
   function now() {
       date '+%Y-%m-%d-%H-%M'
   }

Preprocess the Dataset
~~~~~~~~~~~~~~~~~~~~~~

   **Note:** The following data processing and training commands must be run in the verl-multiturn-rollout environment.

.. code:: bash

   python3 examples/data_preprocess/preprocess_search_r1_dataset.py

Testing on 8 x H20
~~~~~~~~~~~~~~~~~~

.. code:: bash

   # Ensure the now() function is defined
   # Create a logs directory
   mkdir -p logs

   # Set GPUs and run with a suitable log path
   export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

   nohup bash examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh \
       trainer.experiment_name=qwen2.5-3b-it_rm-searchR1-like-sgl-multiturn-$(now) \
       > logs/searchR1-like$(now).log 2>&1 &

Custom Search Configuration
---------------------------

To enable multi-turn reasoning, set the following fields in your config:

.. code:: yaml

   actor_rollout_ref:
     rollout:
       name: "sglang_async"
       multi_turn:
         enable: True

You must specify ``retrieval_service_url`` in ``examples/sglang_multiturn/config/tool_config/search_tool_config.yaml``, and properly configure concurrency. For more details on concurrency, refer to the Sandbox Fusion example:

.. code:: yaml

   tools:
     - class_name: verl.tools.search_tool.SearchTool
       config:
         retrieval_service_url: http://127.0.0.1:8000/retrieve
         num_workers: 120
         rate_limit: 120
         timeout: 30

The retriever input/output formats are as follows. If your service accepts the same parameters, you only need to modify ``retrieval_service_url``. You can also customize the request logic in ``search_r1_like_utils.py``.

.. code:: python

   Input format:
   {
       "queries": ["What is Python?", "Tell me about neural networks."],
       "topk": 3,
       "return_scores": true
   }

   Output format (when return_scores=True, similarity scores are returned):
   {
       "result": [
           [  # Results for each query
               {
                   "document": doc, "score": score
               },
               # ... more documents
           ],
           # ... results for other queries
       ]
   }
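For reference, a minimal Python client for this interface might look like the following sketch (it assumes the local server set up above is running on its default port):

.. code:: python

   import requests

   resp = requests.post(
       "http://127.0.0.1:8000/retrieve",
       json={"queries": ["What is Python?"], "topk": 3, "return_scores": True},
       timeout=30,
   )
   resp.raise_for_status()
   # One list of {"document": ..., "score": ...} entries per query
   for query_results in resp.json()["result"]:
       for item in query_results:
           print(item["score"], item["document"])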
Notes
-----

1. The total training time is about 27 hours. In addition, the validation dataset is very large (51k samples), and each validation pass takes about 6000 s, so ``val_before_train=False`` is set by default.
diff --git a/examples/data_preprocess/preprocess_search_r1_dataset.py b/examples/data_preprocess/preprocess_search_r1_dataset.py
new file mode 100644
index 00000000000..a602d0203be
--- /dev/null
+++ b/examples/data_preprocess/preprocess_search_r1_dataset.py
@@ -0,0 +1,168 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import tempfile

import pandas as pd
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

from verl.utils.hdfs_io import copy, makedirs

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Configuration constants
DEFAULT_SYSTEM_CONTENT = "You are a helpful and harmless assistant."
DEFAULT_USER_CONTENT_PREFIX = (
    "Answer the given question. You must conduct reasoning inside <think> and </think> "
    "first every time you get new information. After reasoning, if you find you lack "
    "some knowledge, you can call a search engine by <search> query </search> "
    "and it will return the top searched results between <information> and "
    "</information>. You can search as many times as you want. If you find no "
    "further external knowledge needed, you can directly provide the answer inside "
    "<answer> and </answer>, without detailed illustrations. For example, "
    "<answer> Beijing </answer>. Question: "
)


def process_single_row(row, current_split_name, row_index):
    """
    Process a single row of data for SearchR1-like format.

    Args:
        row: DataFrame row containing the original data
        current_split_name: Name of the current split (train/test)
        row_index: Index of the row in the DataFrame

    Returns:
        pd.Series: Processed row data in the required format
    """
    question = row.get("question", "")

    # Build prompt structure
    user_content = user_content_prefix.rstrip("\n") + question
    prompt = [{"role": "system", "content": system_content}, {"role": "user", "content": user_content}]

    # Extract ground truth from reward_model or fallback to golden_answers
    reward_model_data = row.get("reward_model")
    if isinstance(reward_model_data, dict) and "ground_truth" in reward_model_data:
        ground_truth = reward_model_data.get("ground_truth")
    else:
        ground_truth = row.get("golden_answers", [])

    # Process data source
    data_source_tagged = "searchR1_" + str(row.get("data_source", ""))

    # Build tools kwargs structure
    tools_kwargs = {"search": {"create_kwargs": {"ground_truth": ground_truth, "question": question, "data_source": data_source_tagged}}}

    # Build complete extra_info structure
    extra_info = {
        "index": row_index,
        "need_tools_kwargs": True,
        "question": question,
        "split": current_split_name,
        "tools_kwargs": tools_kwargs,
    }

    return pd.Series(
        {
            "data_source": data_source_tagged,
            "prompt": prompt,
            "ability": row.get("ability"),
            "reward_model": reward_model_data,
            "extra_info": extra_info,
            "metadata": row.get("metadata"),
        }
    )


def main():
    local_save_dir = os.path.expanduser(args.local_dir)
    os.makedirs(local_save_dir, exist_ok=True)

    processed_files = []

    # Download and process files using temporary directory
    with tempfile.TemporaryDirectory() as tmp_download_dir:
        for split in ["train", "test"]:
            parquet_filename = f"{split}.parquet"
            logger.info(f"Processing {split} split...")

            try:
                # Download Parquet file from HuggingFace
                logger.info(f"Downloading {parquet_filename} from {args.hf_repo_id}")
local_parquet_filepath = hf_hub_download( + repo_id=args.hf_repo_id, + filename=parquet_filename, + repo_type="dataset", + local_dir=tmp_download_dir, + local_dir_use_symlinks=False, + ) + + # Load and process Parquet file + df_raw = pd.read_parquet(local_parquet_filepath) + logger.info(f"Loaded {len(df_raw)} rows from {parquet_filename}") + + def apply_process_row(row, split_name=split): + return process_single_row(row, current_split_name=split_name, row_index=row.name) + + df_processed = df_raw.apply(apply_process_row, axis=1) + + # Save processed DataFrame + output_file_path = os.path.join(local_save_dir, f"{split}.parquet") + df_processed.to_parquet(output_file_path, index=False) + logger.info(f"Saved {len(df_processed)} processed rows to {output_file_path}") + processed_files.append(output_file_path) + + except EntryNotFoundError: + logger.warning(f"{parquet_filename} not found in repository {args.hf_repo_id}") + except Exception as e: + logger.error(f"Error processing {split} split: {e}") + + if not processed_files: + logger.warning("No data was processed or saved") + return + + logger.info(f"Successfully processed {len(processed_files)} files to {local_save_dir}") + + # Copy to HDFS if specified + if args.hdfs_dir: + try: + makedirs(args.hdfs_dir) + copy(src=local_save_dir, dst=args.hdfs_dir) + logger.info(f"Successfully copied files to HDFS: {args.hdfs_dir}") + except Exception as e: + logger.error(f"Error copying files to HDFS: {e}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Download Search-R1 from HuggingFace, process, and save to Parquet.") + parser.add_argument("--hf_repo_id", default="PeterJinGo/nq_hotpotqa_train", help="HuggingFace dataset repository ID.") + parser.add_argument("--local_dir", default="~/data/searchR1_processed_direct", help="Local directory to save the processed Parquet files.") + parser.add_argument("--hdfs_dir", default=None, help="Optional HDFS directory to copy the Parquet files to.") + + args = parser.parse_args() + + # System and user content configuration + system_content = DEFAULT_SYSTEM_CONTENT + user_content_prefix = DEFAULT_USER_CONTENT_PREFIX + + main() diff --git a/examples/sft/gsm8k/run_qwen_05_peft.sh b/examples/sft/gsm8k/run_qwen_05_peft.sh old mode 100755 new mode 100644 diff --git a/examples/sft/gsm8k/run_qwen_05_sp2.sh b/examples/sft/gsm8k/run_qwen_05_sp2.sh old mode 100755 new mode 100644 diff --git a/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh b/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh old mode 100755 new mode 100644 diff --git a/examples/sft/multiturn/run_qwen_05_sp2.sh b/examples/sft/multiturn/run_qwen_05_sp2.sh old mode 100755 new mode 100644 diff --git a/examples/sglang_multiturn/config/search_multiturn_grpo.yaml b/examples/sglang_multiturn/config/search_multiturn_grpo.yaml new file mode 100644 index 00000000000..0c18ecf5110 --- /dev/null +++ b/examples/sglang_multiturn/config/search_multiturn_grpo.yaml @@ -0,0 +1,23 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + shuffle: False + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang_async + multi_turn: + enable: True + max_turns: 2 + format: qwen diff --git a/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml new file mode 100644 index 00000000000..79b647e6207 --- /dev/null +++ 
b/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml
@@ -0,0 +1,22 @@
tools:
  - class_name: verl.tools.search_tool.SearchTool
    config:
      retrieval_service_url: http://127.0.0.1:8000/retrieve
      num_workers: 120
      rate_limit: 120
      timeout: 30
    tool_schema:
      type: function
      function:
        name: search
        description: Searches the web for relevant information based on the given query.
        parameters:
          type: object
          properties:
            query_list:
              type: array
              items:
                type: string
              description: A list of fully-formed semantic queries. The tool will return search results for each query.
          required:
            - query_list
\ No newline at end of file
diff --git a/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py b/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py
new file mode 100644
index 00000000000..6fe554936fa
--- /dev/null
+++ b/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py
@@ -0,0 +1,44 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 Search-R1 Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/scripts/download.py


import argparse

from huggingface_hub import hf_hub_download

parser = argparse.ArgumentParser(description="Download files from a Hugging Face dataset repository.")
parser.add_argument("--repo_id", type=str, default="PeterJinGo/wiki-18-e5-index", help="Hugging Face repository ID")
parser.add_argument("--save_path", type=str, required=True, help="Local directory to save files")

args = parser.parse_args()

repo_id = args.repo_id  # index repository (default: PeterJinGo/wiki-18-e5-index)
for file in ["part_aa", "part_ab"]:  # index shards, concatenated into e5_Flat.index after download
    hf_hub_download(
        repo_id=repo_id,
        filename=file,
        repo_type="dataset",
        local_dir=args.save_path,
    )

repo_id = "PeterJinGo/wiki-18-corpus"
hf_hub_download(
    repo_id=repo_id,
    filename="wiki-18.jsonl.gz",
    repo_type="dataset",
    local_dir=args.save_path,
)
diff --git a/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py b/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py
new file mode 100644
index 00000000000..46d4184a464
--- /dev/null
+++ b/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py
@@ -0,0 +1,383 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 Search-R1 Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/search_r1/search/retrieval_server.py + +import argparse +import json +import warnings +from typing import List, Optional + +import datasets +import faiss +import numpy as np +import torch +import uvicorn +from fastapi import FastAPI +from pydantic import BaseModel +from tqdm import tqdm +from transformers import AutoModel, AutoTokenizer + + +def load_corpus(corpus_path: str): + corpus = datasets.load_dataset("json", data_files=corpus_path, split="train", num_proc=4) + return corpus + + +def load_docs(corpus, doc_idxs): + results = [corpus[int(idx)] for idx in doc_idxs] + return results + + +def load_model(model_path: str, use_fp16: bool = False): + model = AutoModel.from_pretrained(model_path, trust_remote_code=True) + model.eval() + model.cuda() + if use_fp16: + model = model.half() + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True) + return model, tokenizer + + +def pooling(pooler_output, last_hidden_state, attention_mask=None, pooling_method="mean"): + if pooling_method == "mean": + last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + elif pooling_method == "cls": + return last_hidden_state[:, 0] + elif pooling_method == "pooler": + return pooler_output + else: + raise NotImplementedError("Pooling method not implemented!") + + +class Encoder: + def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16): + self.model_name = model_name + self.model_path = model_path + self.pooling_method = pooling_method + self.max_length = max_length + self.use_fp16 = use_fp16 + + self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16) + self.model.eval() + + @torch.no_grad() + def encode(self, query_list: List[str], is_query=True) -> np.ndarray: + # processing query for different encoders + if isinstance(query_list, str): + query_list = [query_list] + + if "e5" in self.model_name.lower(): + if is_query: + query_list = [f"query: {query}" for query in query_list] + else: + query_list = [f"passage: {query}" for query in query_list] + + if "bge" in self.model_name.lower(): + if is_query: + query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list] + + inputs = self.tokenizer(query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt") + inputs = {k: v.cuda() for k, v in inputs.items()} + + if "T5" in type(self.model).__name__: + # T5-based retrieval model + decoder_input_ids = torch.zeros((inputs["input_ids"].shape[0], 1), dtype=torch.long).to(inputs["input_ids"].device) + output = self.model(**inputs, decoder_input_ids=decoder_input_ids, return_dict=True) + query_emb = output.last_hidden_state[:, 0, :] + else: + output = self.model(**inputs, return_dict=True) + query_emb = pooling(output.pooler_output, output.last_hidden_state, inputs["attention_mask"], self.pooling_method) + if "dpr" not in self.model_name.lower(): + query_emb = torch.nn.functional.normalize(query_emb, dim=-1) + + query_emb = query_emb.detach().cpu().numpy() + query_emb = query_emb.astype(np.float32, order="C") + + del inputs, output + torch.cuda.empty_cache() + + return query_emb + + +class BaseRetriever: + def __init__(self, config): + self.config = config + 
self.retrieval_method = config.retrieval_method + self.topk = config.retrieval_topk + + self.index_path = config.index_path + self.corpus_path = config.corpus_path + + def _search(self, query: str, num: int, return_score: bool): + raise NotImplementedError + + def _batch_search(self, query_list: List[str], num: int, return_score: bool): + raise NotImplementedError + + def search(self, query: str, num: int = None, return_score: bool = False): + return self._search(query, num, return_score) + + def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False): + return self._batch_search(query_list, num, return_score) + + +class BM25Retriever(BaseRetriever): + def __init__(self, config): + super().__init__(config) + from pyserini.search.lucene import LuceneSearcher + + self.searcher = LuceneSearcher(self.index_path) + self.contain_doc = self._check_contain_doc() + if not self.contain_doc: + self.corpus = load_corpus(self.corpus_path) + self.max_process_num = 8 + + def _check_contain_doc(self): + return self.searcher.doc(0).raw() is not None + + def _search(self, query: str, num: int = None, return_score: bool = False): + if num is None: + num = self.topk + hits = self.searcher.search(query, num) + if len(hits) < 1: + if return_score: + return [], [] + else: + return [] + scores = [hit.score for hit in hits] + if len(hits) < num: + warnings.warn("Not enough documents retrieved!", stacklevel=2) + else: + hits = hits[:num] + + if self.contain_doc: + all_contents = [json.loads(self.searcher.doc(hit.docid).raw())["contents"] for hit in hits] + results = [{"title": content.split("\n")[0].strip('"'), "text": "\n".join(content.split("\n")[1:]), "contents": content} for content in all_contents] + else: + results = load_docs(self.corpus, [hit.docid for hit in hits]) + + if return_score: + return results, scores + else: + return results + + def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False): + results = [] + scores = [] + for query in query_list: + item_result, item_score = self._search(query, num, True) + results.append(item_result) + scores.append(item_score) + if return_score: + return results, scores + else: + return results + + +class DenseRetriever(BaseRetriever): + def __init__(self, config): + super().__init__(config) + self.index = faiss.read_index(self.index_path) + if config.faiss_gpu: + co = faiss.GpuMultipleClonerOptions() + co.useFloat16 = True + co.shard = True + self.index = faiss.index_cpu_to_all_gpus(self.index, co=co) + + self.corpus = load_corpus(self.corpus_path) + self.encoder = Encoder(model_name=self.retrieval_method, model_path=config.retrieval_model_path, pooling_method=config.retrieval_pooling_method, max_length=config.retrieval_query_max_length, use_fp16=config.retrieval_use_fp16) + self.topk = config.retrieval_topk + self.batch_size = config.retrieval_batch_size + + def _search(self, query: str, num: int = None, return_score: bool = False): + if num is None: + num = self.topk + query_emb = self.encoder.encode(query) + scores, idxs = self.index.search(query_emb, k=num) + idxs = idxs[0] + scores = scores[0] + results = load_docs(self.corpus, idxs) + if return_score: + return results, scores.tolist() + else: + return results + + def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False): + if isinstance(query_list, str): + query_list = [query_list] + if num is None: + num = self.topk + + results = [] + scores = [] + for start_idx in tqdm(range(0, len(query_list), 
self.batch_size), desc="Retrieval process: "):
            query_batch = query_list[start_idx : start_idx + self.batch_size]
            batch_emb = self.encoder.encode(query_batch)
            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
            batch_scores = batch_scores.tolist()
            batch_idxs = batch_idxs.tolist()

            # load_docs is not vectorized; fetch all docs for the flattened index list,
            # then chunk the flat list back into per-query groups
            flat_idxs = sum(batch_idxs, [])
            batch_results = load_docs(self.corpus, flat_idxs)
            batch_results = [batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))]

            results.extend(batch_results)
            scores.extend(batch_scores)

            del batch_emb, batch_scores, batch_idxs, query_batch, flat_idxs, batch_results
            torch.cuda.empty_cache()

        if return_score:
            return results, scores
        else:
            return results


def get_retriever(config):
    if config.retrieval_method == "bm25":
        return BM25Retriever(config)
    else:
        return DenseRetriever(config)


#####################################
# FastAPI server below
#####################################


class Config:
    """
    Minimal config class (simulating your argparse).
    Replace this with your real arguments or load them dynamically.
    """

    def __init__(
        self,
        retrieval_method: str = "bm25",
        retrieval_topk: int = 10,
        index_path: str = "./index/bm25",
        corpus_path: str = "./data/corpus.jsonl",
        dataset_path: str = "./data",
        data_split: str = "train",
        faiss_gpu: bool = True,
        retrieval_model_path: str = "./model",
        retrieval_pooling_method: str = "mean",
        retrieval_query_max_length: int = 256,
        retrieval_use_fp16: bool = False,
        retrieval_batch_size: int = 128,
    ):
        self.retrieval_method = retrieval_method
        self.retrieval_topk = retrieval_topk
        self.index_path = index_path
        self.corpus_path = corpus_path
        self.dataset_path = dataset_path
        self.data_split = data_split
        self.faiss_gpu = faiss_gpu
        self.retrieval_model_path = retrieval_model_path
        self.retrieval_pooling_method = retrieval_pooling_method
        self.retrieval_query_max_length = retrieval_query_max_length
        self.retrieval_use_fp16 = retrieval_use_fp16
        self.retrieval_batch_size = retrieval_batch_size


class QueryRequest(BaseModel):
    queries: List[str]
    topk: Optional[int] = None
    return_scores: bool = False


app = FastAPI()


@app.post("/retrieve")
def retrieve_endpoint(request: QueryRequest):
    """
    Endpoint that accepts queries and performs retrieval.

    Input format:
    {
        "queries": ["What is Python?", "Tell me about neural networks."],
        "topk": 3,
        "return_scores": true
    }

    Output format (when return_scores=True, similarity scores are returned):
    {
        "result": [
            [  # Results for each query
                {"document": doc, "score": score},
                # ... more documents
            ],
            # ... results for other queries
        ]
    }
    """
    if not request.topk:
        request.topk = config.retrieval_topk  # fallback to default

    # Perform batch retrieval; batch_search only returns scores when asked for them
    if request.return_scores:
        results, scores = retriever.batch_search(query_list=request.queries, num=request.topk, return_score=True)
    else:
        results = retriever.batch_search(query_list=request.queries, num=request.topk, return_score=False)
        scores = None

    # Format response
    resp = []
    for i, single_result in enumerate(results):
        if request.return_scores:
            # If scores are returned, combine them with results
            combined = []
            for doc, score in zip(single_result, scores[i]):
                combined.append({"document": doc, "score": score})
            resp.append(combined)
        else:
            resp.append(single_result)
    return {"result": resp}


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
    parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
    parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
    parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
    parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
    parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
    parser.add_argument("--faiss_gpu", action="store_true", help="Use GPU for computation")

    args = parser.parse_args()

    # 1) Build a config (could also parse from arguments).
    #    In real usage, you'd parse your CLI arguments or environment variables.
    config = Config(
        retrieval_method=args.retriever_name,  # or "dense"
        index_path=args.index_path,
        corpus_path=args.corpus_path,
        retrieval_topk=args.topk,
        faiss_gpu=args.faiss_gpu,
        retrieval_model_path=args.retriever_model,
        retrieval_pooling_method="mean",
        retrieval_query_max_length=256,
        retrieval_use_fp16=True,
        retrieval_batch_size=512,
    )

    # 2) Instantiate a global retriever so it is loaded once and reused.
    retriever = get_retriever(config)

    # 3) Launch the server.
By default, it listens on http://127.0.0.1:8000 + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh b/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh new file mode 100644 index 00000000000..11becfce7fd --- /dev/null +++ b/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh @@ -0,0 +1,66 @@ +# run on 8xH20 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + + +TRAIN_DATA="$HOME/data/searchR1_processed_direct/train.parquet" +VAL_DATA="$HOME/data/searchR1_processed_direct/test.parquet" + +TOOL_CONFIG="$CONFIG_PATH/tool_config/search_tool_config.yaml" + + + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='search_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=512 \ + data.val_batch_size=256 \ + data.max_prompt_length=4096 \ + data.max_response_length=3000 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.max_model_len=15000 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.multi_turn.max_turns=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.val_before_train=False \ + trainer.logger=['console','wandb'] \ + trainer.project_name='search_r1_like_async_rl' \ + trainer.experiment_name='qwen2.5-3b-instruct_function_rm-search-async-sgl-multi-w-searchtool-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=100 \ + trainer.test_freq=50 \ + data.train_files="$TRAIN_DATA" \ + data.val_files="$VAL_DATA" \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$TOOL_CONFIG" \ + trainer.total_epochs=1 $@ + diff --git a/requirements_sglang.txt b/requirements_sglang.txt index 470851a4bf7..2f99c97862c 100644 --- a/requirements_sglang.txt +++ b/requirements_sglang.txt @@ -18,4 +18,5 @@ torchvision transformers wandb sglang[all]==0.4.6.post4 -torch-memory-saver>=0.0.5 \ No newline at end of file +torch-memory-saver>=0.0.5 +huggingface_hub \ No newline at end of file diff --git a/tests/workers/rollout/resource/tool_configs/search_tool_config b/tests/workers/rollout/resource/tool_configs/search_tool_config new file mode 100644 index 00000000000..79b647e6207 --- /dev/null +++ 
b/tests/workers/rollout/resource/tool_configs/search_tool_config
@@ -0,0 +1,22 @@
tools:
  - class_name: verl.tools.search_tool.SearchTool
    config:
      retrieval_service_url: http://127.0.0.1:8000/retrieve
      num_workers: 120
      rate_limit: 120
      timeout: 30
    tool_schema:
      type: function
      function:
        name: search
        description: Searches the web for relevant information based on the given query.
        parameters:
          type: object
          properties:
            query_list:
              type: array
              items:
                type: string
              description: A list of fully-formed semantic queries. The tool will return search results for each query.
          required:
            - query_list
\ No newline at end of file
diff --git a/tests/workers/rollout/test_sglang_async_rollout_search_tools.py b/tests/workers/rollout/test_sglang_async_rollout_search_tools.py
new file mode 100644
index 00000000000..655e6124a13
--- /dev/null
+++ b/tests/workers/rollout/test_sglang_async_rollout_search_tools.py
@@ -0,0 +1,345 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from tests/workers/rollout/test_sglang_async_rollout_sf_tools.py


import asyncio
from copy import deepcopy
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
from tensordict import TensorDict
from transformers import AutoConfig, AutoTokenizer
from utils_sglang import (
    get_rollout_config,
    prepare_inputs,
)

from verl.protocol import DataProto
from verl.tools.schemas import OpenAIFunctionParametersSchema, OpenAIFunctionPropertySchema, OpenAIFunctionSchema, OpenAIFunctionToolSchema
from verl.tools.search_tool import SearchTool
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message
from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout

DEFAULT_USER_CONTENT_PREFIX = (
    "Answer the given question. You must conduct reasoning inside <think> and </think> "
    "first every time you get new information. After reasoning, if you find you lack "
    "some knowledge, you can call a search engine by <search> query </search> "
    "and it will return the top searched results between <information> and "
    "</information>. You can search as many times as you want. If you find no "
    "further external knowledge needed, you can directly provide the answer inside "
    "<answer> and </answer>, without detailed illustrations. For example, "
    "<answer> Beijing </answer>. Question: "
)
user_content = DEFAULT_USER_CONTENT_PREFIX.rstrip("\n") + "How's the weather lately?"
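# The mocked conversation below contains three assistant turns: two that issue
# search tool calls and a final one that answers directly. The two "tool"
# messages are the canned retrieval results returned by the mocked search tool.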
+ + +def get_search_messages(): + user_prompt = { + "role": "user", + "content": user_content, + } + + expect_turn_0_msg = { + "role": "assistant", + "content": "Let me search the web.", + "tool_calls": [{"type": "function", "function": {"name": "search", "arguments": {"query": "today's weather"}}}], + } + + expect_turn_1_msg = { + "role": "assistant", + "content": "Let me search again.", + "tool_calls": [{"type": "function", "function": {"name": "search", "arguments": {"query": "tomorrow's weather"}}}], + } + + expect_turn_2_msg = { + "role": "assistant", + "content": "Today is sunny and tomorrow will be cloudy in Beijing.", + } + + # Mock search tool responses + tool_return_0_msg = {"role": "tool", "content": "Today's weather in Beijing is sunny."} + tool_return_1_msg = {"role": "tool", "content": "Tomorrow's weather in Beijing is cloudy."} + + user_prompts = [user_prompt] + expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg] + tool_return_array = [tool_return_0_msg, tool_return_1_msg] + + return user_prompts, expect_turn_array, tool_return_array + + +class TestRolloutWithSearchTools: + @pytest.fixture + def qwen_tokenizer(self): + local_model_path = "Qwen/Qwen2.5-0.5B" + tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + # we only need this for tokenizer + @pytest.fixture + def qwen_model_config(self): + local_model_path = "Qwen/Qwen2.5-0.5B" + config = AutoConfig.from_pretrained(local_model_path) + return config + + @pytest.fixture + def search_data(self, qwen_tokenizer): + user_prompt, expect_turn_array, tool_return_array = get_search_messages() + prompts = [[message] for message in user_prompt] + preencode_turn_array = [qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) for turn in expect_turn_array] + preencode_tool_return_array = [qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) for turn in tool_return_array] + return prompts, preencode_turn_array, preencode_tool_return_array + + @pytest.fixture + def search_rollout_config(self): + max_prompt_length = 4096 + max_response_length = 3000 + dtype = "bfloat16" + tensor_parallel_size = 1 + tool_path = "./resource/tool_configs/search_tool_config" + rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path) + return rollout_config + + @pytest.fixture + def search_data_proto(self, search_data, qwen_tokenizer): + preencode_prompts, _, _ = search_data + prompts = [qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in preencode_prompts] + input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) + prompt_dict = TensorDict( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + }, + batch_size=input_ids.shape[0], + ) + messages = np.asarray(preencode_prompts) + + tools_kwargs = np.array( + [ + { + "search": { + "create_kwargs": {"ground_truth": "Today is sunny and tomorrow will be cloudy in Beijing.", "data_source": "searchR1_nq"}, + }, + } + ], + dtype=object, + ) + index = np.array([0], dtype=object) + prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index}) + return prompts + + @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(AsyncSGLangRollout, 
"_init_inference_engine", return_value=None) + @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + def test_tools_registration(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config): + rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + assert len(rollout._tool_schemas) == 1 + assert "search" in rollout._tool_map.keys() + from verl.tools.search_tool import SearchTool + + assert isinstance(rollout._tool_map["search"], SearchTool) + # depend on the tokenizer + assert rollout._tool_call_parser_type == "qwen25" + + @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto): + rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1) + assert len(req_list) == 1 + assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING + assert len(req_list[0].tools) == 1 + print(type(req_list[0].tools[0])) + assert req_list[0].tools[0] == OpenAIFunctionToolSchema( + type="function", + function=OpenAIFunctionSchema( + name="search", + description="Searches the web for relevant information based on the given query.", + parameters=OpenAIFunctionParametersSchema( + type="object", + properties={ + "query_list": OpenAIFunctionPropertySchema( + type="array", + description="A list of fully-formed semantic queries. The tool will return search results for each query.", + items={"type": "string"}, + ) + }, + required=["query_list"], + ), + strict=False, + ), + ) + + @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + def test_over_size_case(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data): + search_rollout_config.multi_turn.max_turns = 1 + rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + req = MagicMock(wraps=req, spec=AsyncRolloutRequest) + req.finalize = MagicMock() + req_list = [req] + + _, expect_turn_array, _ = search_data + # here we mock a meta info with 'length'. 
# i.e. the response was truncated
        rollout._handle_engine_call = MagicMock()
        future = asyncio.Future()
        future.set_result({"text": expect_turn_array[0], "meta_info": {"id": "d1188d81cba840359df5b352b344bc8e", "finish_reason": {"type": "length", "length": 3000}, "prompt_tokens": 132, "completion_tokens": 100, "cached_tokens": 0, "e2e_latency": 2.23543}})
        rollout._handle_engine_call.return_value = future
        rollout._tp_rank = 0
        loop = asyncio.get_event_loop()
        output_req_list = loop.run_until_complete(
            asyncio.gather(
                *[rollout._async_rollout_a_request(req, True, False) for req in req_list],
            )
        )
        assert len(output_req_list) == 1
        output_req = output_req_list[0]
        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED
        assert output_req.reward_scores == {"search": []}, f"output_req.reward_scores: {output_req.reward_scores}"
        # we should only have two messages: one for the prompt, one for the response
        assert len(output_req.messages) == 2
        assert output_req.messages[1] == Message(
            role="assistant",
            content=expect_turn_array[0],
            tool_calls=None,
        )

    @patch.object(SearchTool, "execute", new_callable=AsyncMock)
    @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
    @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
    @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
    def test_tool_call_basic_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data):
        _, expect_turn_array, tool_return_array = search_data

        # Mock search tool execution to return predefined responses
        mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array]

        search_rollout_config.multi_turn.max_turns = 10
        rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)

        rollout._tool_map["search"].retrieval_service_url = "mock://dummy"

        req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]
        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)
        req.finalize = MagicMock()
        req_list = [req]

        rollout._handle_engine_call = MagicMock()
        futures = [asyncio.Future() for _ in expect_turn_array]
        for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array)):
            fut.set_result({"text": turn, "meta_info": {"id": "d1188d81cba840359df5b352b344bc8e", "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, "prompt_tokens": len(turn), "completion_tokens": 100, "cached_tokens": 0, "e2e_latency": 2.23543}})
            if idx < len(expect_turn_array) - 1:
                assert rollout._function_call_parser.has_tool_call(turn)
                assert rollout._function_call_parser.parse_non_stream(turn)

        rollout._handle_engine_call.side_effect = futures
        rollout._tp_rank = 0

        loop = asyncio.get_event_loop()
        output_req_list = loop.run_until_complete(asyncio.gather(*[rollout._async_rollout_a_request(req, True, False) for req in req_list]))

        # Verify conversation completed successfully with proper tool usage
        output_req = output_req_list[0]
        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED
        assert "search" in output_req.metrics
        assert output_req.metrics["search"][0]["status"] == "success"
        assert mock_execute.await_count == 2
        assert len(output_req.messages) == 6  # user + 3*assistant + 2*tool_call
        # Verify tool response messages contain expected
content + search_counter = 0 + for msg in output_req.messages: + if msg.role == "tool": + assert msg.content == tool_return_array[search_counter] + search_counter += 1 + assert search_counter == 2 + + @patch.object(SearchTool, "execute", new_callable=AsyncMock) + @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + def test_tool_call_batch_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data): + _, expect_turn_array, tool_return_array = search_data + + # Mock tool execution for large batch (100 requests * 2 calls each) + mock_execute.side_effect = [ + (tool_return_array[0], 0.0, {"status": "success"}), + (tool_return_array[1], 0.0, {"status": "success"}), + ] * 100 + + search_rollout_config.multi_turn.max_turns = 10 + rollout = AsyncSGLangRollout( + actor_module="", + config=search_rollout_config, + tokenizer=qwen_tokenizer, + model_hf_config=qwen_model_config, + ) + rollout._tool_map["search"].retrieval_service_url = "mock://dummy" + + base_req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + + req_nums = 100 + req_list = [] + req_turns_map = {} + req_turns_counter = {} + + for i in range(req_nums): + tmp_req = deepcopy(base_req) + tmp_req.batch_data_id = i + tmp_req.request_id = i + req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest)) + + futures = [asyncio.Future() for _ in expect_turn_array] + for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array)): + fut.set_result( + { + "text": turn, + "meta_info": { + "id": "dummy", + "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "prompt_tokens": len(turn), + "completion_tokens": 100, + }, + } + ) + req_turns_map[i] = futures + req_turns_counter[i] = 0 + + async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs): + fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]] + req_turns_counter[_req.batch_data_id] += 1 + return await fut + + with patch.object(AsyncSGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): + rollout._tp_rank = 0 + loop = asyncio.get_event_loop() + output_req_list = loop.run_until_complete(asyncio.gather(*[rollout._async_rollout_a_request(r, True, False) for r in req_list])) + + # Verify all requests completed successfully + assert len(output_req_list) == req_nums + for out_req in output_req_list: + assert out_req.state == AsyncRolloutRequestStateEnum.COMPLETED + assert "search" in out_req.metrics + for metric in out_req.metrics["search"]: + assert metric["status"] == "success" + assert len(out_req.messages) == 6 # user + 3 assistant + 2 tool + assert sum(1 for m in out_req.messages if m.role == "tool") == 2 + + assert mock_execute.await_count == 2 * req_nums diff --git a/verl/tools/search_tool.py b/verl/tools/search_tool.py new file mode 100644 index 00000000000..b66200a4312 --- /dev/null +++ b/verl/tools/search_tool.py @@ -0,0 +1,258 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import threading +from contextlib import ExitStack +from enum import Enum +from typing import Any, Callable, Optional, Tuple, TypeVar +from uuid import uuid4 + +import ray +import ray.actor + +from verl.tools.utils.search_r1_like_utils import perform_single_search_batch + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +T = TypeVar("T") + + +# Adapted from verl/tools/sandbox_fusion_tools.py +class PoolMode(Enum): + """Execution pool mode enumeration.""" + + ThreadMode = 1 + ProcessMode = 2 + + +@ray.remote(concurrency_groups={"acquire": 1, "release": 10}) +class TokenBucketWorker: + """Ray actor for rate limiting using token bucket algorithm.""" + + def __init__(self, rate_limit: int): + self.rate_limit = rate_limit + self.current_count = 0 # For observability + self._semaphore = threading.Semaphore(rate_limit) + + @ray.method(concurrency_group="acquire") + def acquire(self): + """Acquire a token from the bucket.""" + self._semaphore.acquire() + self.current_count += 1 + + @ray.method(concurrency_group="release") + def release(self): + """Release a token back to the bucket.""" + self._semaphore.release() + self.current_count -= 1 + + def get_current_count(self): + """Get current number of acquired tokens.""" + return self.current_count + + +class SearchExecutionWorker: + """Worker for executing search operations with optional rate limiting.""" + + def __init__(self, enable_global_rate_limit=True, rate_limit=10): + self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + + def _init_rate_limit(self, rate_limit): + """Initialize singleton rate limiter.""" + return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) + + def ping(self): + """Health check method.""" + return True + + def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: + """Execute function with optional rate limiting.""" + if self.rate_limit_worker: + with ExitStack() as stack: + stack.callback(self.rate_limit_worker.release.remote) + ray.get(self.rate_limit_worker.acquire.remote()) + try: + return fn(*fn_args, **fn_kwargs) + except Exception as e: + # TODO we should make this available to the tool caller + logger.warning(f"Error when executing search: {e}") + else: + return fn(*fn_args, **fn_kwargs) + + +def init_search_execution_pool(num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode): + """Initialize search execution pool.""" + if mode == PoolMode.ThreadMode: + return ray.remote(SearchExecutionWorker).options(max_concurrency=num_workers).remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + else: + raise NotImplementedError("Process mode is not implemented yet") + + +class SearchTool(BaseTool): + """Search tool for retrieving information using external retrieval services. 
+ + This tool provides search functionality with rate limiting and concurrent execution + support through Ray. It integrates with external retrieval services to perform + semantic search operations. + + Methods: + get_openai_tool_schema: Return the tool schema in OpenAI format + create: Create a tool instance for a trajectory + execute: Execute the search tool + calc_reward: Calculate the reward with respect to tool state + release: Release the tool instance + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """Initialize SearchTool with configuration and schema. + + Args: + config: Configuration dictionary containing tool settings + tool_schema: OpenAI function tool schema definition + + Example tool_schema: + { + "type": "function", + "function": { + "name": "search", + "description": "Searches for relevant information based on queries.", + "parameters": { + "type": "object", + "properties": { + "query_list": { + "type": "array", + "items": {"type": "string"}, + "description": "List of search queries" + } + }, + "required": ["query_list"] + } + } + } + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + # Worker and rate limiting configuration + self.num_workers = config.get("num_workers", 120) + self.rate_limit = config.get("rate_limit", 120) + self.timeout = config.get("timeout", 30) + + self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) + self.execution_pool = init_search_execution_pool(num_workers=self.num_workers, enable_global_rate_limit=self.enable_global_rate_limit, rate_limit=self.rate_limit, mode=PoolMode.ThreadMode) + + # Retrieval service configuration + self.retrieval_service_url = config.get("retrieval_service_url") + assert self.retrieval_service_url, "Configuration must include 'retrieval_service_url'" + self.topk = config.get("topk", 3) + if self.retrieval_service_url == "": + raise ValueError("retrieval_service_url is not set") + + logger.info(f"Initialized SearchTool with config: {config}") + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + """Return the OpenAI tool schema.""" + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "reward": [], + } + return instance_id + + def execute_search(self, instance_id: str, query_list: list, retrieval_service_url: str, topk: int, timeout: int): + """Execute search operation using retrieval service. + + Args: + instance_id: Tool instance ID + query_list: List of search queries + retrieval_service_url: URL of the retrieval service + topk: Number of top results to return + timeout: Request timeout in seconds + + Returns: + Tuple of (result_text, metadata) + """ + result_text, metadata = perform_single_search_batch( + retrieval_service_url=retrieval_service_url, + query_list=query_list, + topk=topk, + concurrent_semaphore=None, # Ray handles concurrency control + timeout=timeout, + ) + logger.debug(f"Search result for instance {instance_id}: {result_text}") + return result_text, metadata + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: + """Execute the search tool. 
+
+        Args:
+            instance_id: The instance ID of the tool
+            parameters: Tool parameters containing query_list and optional timeout
+
+        Returns: tool_response, tool_reward_score, tool_metrics
+            tool_response: The response str of the tool.
+            tool_reward_score: The step reward score of the tool.
+            tool_metrics: The metrics of the tool.
+        """
+        timeout = self.timeout
+        query_list_from_params = parameters.get("query_list")
+
+        if not query_list_from_params or not isinstance(query_list_from_params, list):
+            error_msg = "Error: 'query_list' is missing, empty, or not a list in parameters."
+            logger.error(f"[SearchTool] {error_msg} Received parameters: {parameters}")
+            return json.dumps({"result": error_msg}), 0.0, {}
+
+        # Execute the search via the Ray execution pool
+        try:
+            result_text, metadata = await self.execution_pool.execute.remote(self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout)
+
+            # Store results in the instance dictionary
+            self._instance_dict[instance_id]["reward"].append(result_text.strip())
+
+            # Convert metadata to metrics
+            metrics = {"query_count": metadata.get("query_count", 0), "status": metadata.get("status", "unknown"), "total_results": metadata.get("total_results", 0), "api_request_error": metadata.get("api_request_error")}
+
+            return result_text, 0.0, metrics
+
+        except Exception as e:
+            error_result = json.dumps({"result": f"Search execution failed: {e}"})
+            logger.error(f"[SearchTool] Execution failed: {e}")
+            return error_result, 0.0, {"error": str(e)}
+
+    async def calc_reward(self, instance_id: str, **kwargs) -> list:
+        """Return the accumulated search responses for this instance."""
+        return self._instance_dict[instance_id]["reward"]
+
+    async def release(self, instance_id: str, **kwargs) -> None:
+        """Release the tool instance and free its state."""
+        if instance_id in self._instance_dict:
+            del self._instance_dict[instance_id]
diff --git a/verl/tools/utils/__init__.py b/verl/tools/utils/__init__.py
new file mode 100644
index 00000000000..c4b932b1ae7
--- /dev/null
+++ b/verl/tools/utils/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/verl/tools/utils/search_r1_like_utils.py b/verl/tools/utils/search_r1_like_utils.py
new file mode 100644
index 00000000000..8a3bb1bbaab
--- /dev/null
+++ b/verl/tools/utils/search_r1_like_utils.py
@@ -0,0 +1,214 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
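+
+# NOTE: The helpers in this module assume a search-R1-style HTTP retrieval
+# service. As a rough sketch (the exact response schema depends on how the
+# retriever is deployed; field values here are illustrative), one exchange
+# looks like:
+#
+#   POST {retrieval_service_url}
+#   {"queries": ["who wrote Hamlet"], "topk": 3, "return_scores": true}
+#
+#   -> {"result": [[{"document": {"contents": "Hamlet\nA tragedy written by
+#      William Shakespeare..."}, "score": 0.91}, ...]]}
+#
+# i.e. one ranked list of documents per query, where each document's
+# "contents" carries the title on its first line; _passages2string below
+# relies on that layout.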
+
+import json
+import logging
+import threading
+import time
+import traceback
+import uuid
+from typing import Any, Dict, List, Optional, Tuple
+
+import requests
+
+DEFAULT_TIMEOUT = 30  # Default search request timeout in seconds
+MAX_RETRIES = 10
+INITIAL_RETRY_DELAY = 1
+
+logger = logging.getLogger(__name__)
+
+
+def call_search_api(retrieval_service_url: str, query_list: List[str], topk: int = 3, return_scores: bool = True, timeout: int = DEFAULT_TIMEOUT) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
+    """
+    Calls the remote search API with retry logic for transient errors, using an
+    increasing delay between retries. Logs each call with a unique request ID.
+
+    Args:
+        retrieval_service_url: The URL of the retrieval service API.
+        query_list: List of search queries.
+        topk: Number of top results to return.
+        return_scores: Whether to return scores.
+        timeout: Request timeout in seconds.
+
+    Returns:
+        A tuple (response_json, error_message).
+        If successful, response_json is the API's returned JSON object and error_message is None.
+        If it fails after retries, response_json is None and error_message contains the error information.
+    """
+    request_id = str(uuid.uuid4())
+    log_prefix = f"[Search Request ID: {request_id}] "
+
+    payload = {"queries": query_list, "topk": topk, "return_scores": return_scores}
+
+    headers = {"Content-Type": "application/json", "Accept": "application/json"}
+
+    last_error = None
+
+    for attempt in range(MAX_RETRIES):
+        try:
+            logger.info(f"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling search API at {retrieval_service_url}")
+            response = requests.post(
+                retrieval_service_url,
+                headers=headers,
+                json=payload,
+                timeout=timeout,
+            )
+
+            # Retry on Gateway Timeout (504) and other transient server errors
+            if response.status_code in [500, 502, 503, 504]:
+                last_error = f"{log_prefix}API Request Error: Server Error ({response.status_code}) on attempt {attempt + 1}/{MAX_RETRIES}"
+                logger.warning(last_error)
+                if attempt < MAX_RETRIES - 1:
+                    delay = INITIAL_RETRY_DELAY * (attempt + 1)
+                    logger.info(f"{log_prefix}Retrying after {delay} seconds...")
+                    time.sleep(delay)
+                    continue
+
+            # Raise for other HTTP errors (e.g., 4xx)
+            response.raise_for_status()
+
+            # Success (status code 2xx)
+            logger.info(f"{log_prefix}Search API call successful on attempt {attempt + 1}")
+            return response.json(), None
+
+        except requests.exceptions.ConnectionError as e:
+            last_error = f"{log_prefix}Connection Error: {e}"
+            logger.warning(last_error)
+            if attempt < MAX_RETRIES - 1:
+                delay = INITIAL_RETRY_DELAY * (attempt + 1)
+                logger.info(f"{log_prefix}Retrying after {delay} seconds...")
+                time.sleep(delay)
+                continue
+        except requests.exceptions.Timeout as e:
+            last_error = f"{log_prefix}Timeout Error: {e}"
+            logger.warning(last_error)
+            if attempt < MAX_RETRIES - 1:
+                delay = INITIAL_RETRY_DELAY * (attempt + 1)
+                logger.info(f"{log_prefix}Retrying after {delay} seconds...")
+                time.sleep(delay)
+                continue
+        # Catch JSON decode errors before RequestException: on requests>=2.27,
+        # response.json() raises requests.exceptions.JSONDecodeError, which is
+        # also a RequestException, so the order of these handlers matters.
+        except json.JSONDecodeError as e:
+            raw_response_text = response.text if "response" in locals() else "N/A"
+            last_error = f"{log_prefix}API Response JSON Decode Error: {e}, Response: {raw_response_text[:200]}"
+            break  # Exit the retry loop on JSON decode errors
+        except requests.exceptions.RequestException as e:
+            last_error = f"{log_prefix}API Request Error: {e}"
+            break  # Exit the retry loop on other request errors
+        except Exception as e:
+            last_error = f"{log_prefix}Unexpected Error: {e}"
+            break  # Exit the retry loop on other unexpected errors
+
+    # If the loop finishes without returning success, return the last recorded error
+    logger.error(f"{log_prefix}Search API call failed. Last error: {last_error}")
+    return None, last_error.replace(log_prefix, "API Call Failed: ") if last_error else "API Call Failed after retries"
+
+
+def _passages2string(retrieval_result):
+    """Convert retrieval results to a formatted string."""
+    format_reference = ""
+    for idx, doc_item in enumerate(retrieval_result):
+        content = doc_item["document"]["contents"]
+        title = content.split("\n")[0]
+        text = "\n".join(content.split("\n")[1:])
+        format_reference += f"Doc {idx + 1} (Title: {title})\n{text}\n\n"
+    return format_reference.strip()
+
+
+def perform_single_search_batch(retrieval_service_url: str, query_list: List[str], topk: int = 3, concurrent_semaphore: Optional[threading.Semaphore] = None, timeout: int = DEFAULT_TIMEOUT) -> Tuple[str, Dict[str, Any]]:
+    """
+    Performs a single batch search for multiple queries (the original search tool behavior).
+
+    Args:
+        retrieval_service_url: The URL of the retrieval service API.
+        query_list: List of search queries.
+        topk: Number of top results to return.
+        concurrent_semaphore: Optional semaphore for concurrency control.
+        timeout: Request timeout in seconds.
+
+    Returns:
+        A tuple (result_text, metadata).
+        result_text: The search result JSON string.
+        metadata: Metadata dictionary for the batch search.
+    """
+    logger.info(f"Starting batch search for {len(query_list)} queries.")
+
+    api_response = None
+    error_msg = None
+
+    try:
+        if concurrent_semaphore:
+            with concurrent_semaphore:
+                api_response, error_msg = call_search_api(retrieval_service_url=retrieval_service_url, query_list=query_list, topk=topk, return_scores=True, timeout=timeout)
+        else:
+            api_response, error_msg = call_search_api(retrieval_service_url=retrieval_service_url, query_list=query_list, topk=topk, return_scores=True, timeout=timeout)
+    except Exception as e:
+        error_msg = f"API Request Exception during batch search: {e}"
+        logger.error(f"Batch search: {error_msg}")
+        traceback.print_exc()
+
+    metadata = {
+        "query_count": len(query_list),
+        "queries": query_list,
+        "api_request_error": error_msg,
+        "api_response": None,
+        "status": "unknown",
+        "total_results": 0,
+        "formatted_result": None,
+    }
+
+    result_text = json.dumps({"result": "Search request failed or timed out after retries."})
+
+    if error_msg:
+        metadata["status"] = "api_error"
+        result_text = json.dumps({"result": f"Search error: {error_msg}"})
+        logger.error(f"Batch search: API error occurred: {error_msg}")
+    elif api_response:
+        logger.debug(f"Batch search: API Response: {api_response}")
+        metadata["api_response"] = api_response
+
+        try:
+            raw_results = api_response.get("result", [])
+            if raw_results:
+                pretty_results = []
+                total_results = 0
+
+                for retrieval in raw_results:
+                    formatted = _passages2string(retrieval)
+                    pretty_results.append(formatted)
+                    total_results += len(retrieval) if isinstance(retrieval, list) else 1
+
+                final_result = "\n---\n".join(pretty_results)
+                result_text = json.dumps({"result": final_result})
+                metadata["status"] = "success"
+                metadata["total_results"] = total_results
+                metadata["formatted_result"] = final_result
+                logger.info(f"Batch search: Successful, got {total_results} total results")
+            else:
+                result_text = json.dumps({"result": "No search results found."})
+                metadata["status"] = "no_results"
+                metadata["total_results"] = 0
+                logger.info("Batch search: No results found")
+        except Exception as e:
+            error_msg = f"Error processing search results: {e}"
+            result_text = json.dumps({"result": error_msg})
+            metadata["status"] = "processing_error"
+            logger.error(f"Batch search: {error_msg}")
+    else:
+        metadata["status"] = "unknown_api_state"
+        result_text = json.dumps({"result": "Unknown API state (no response and no error message)."})
+        logger.error("Batch search: Unknown API state.")
+
+    return result_text, metadata
diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py
index 1466e498d88..16f22e137cd 100644
--- a/verl/utils/reward_score/__init__.py
+++ b/verl/utils/reward_score/__init__.py
@@ -79,6 +79,10 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
         from . import geo3k
 
         res = geo3k.compute_score(solution_str, ground_truth)
+    elif data_source in ["searchR1_nq", "searchR1_triviaqa", "searchR1_popqa", "searchR1_hotpotqa", "searchR1_2wikimultihopqa", "searchR1_musique", "searchR1_bamboogle"]:
+        from . import search_r1_like_qa_em
+
+        res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)
     else:
         raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
 
diff --git a/verl/utils/reward_score/search_r1_like_qa_em.py b/verl/utils/reward_score/search_r1_like_qa_em.py
new file mode 100644
index 00000000000..56782fcb343
--- /dev/null
+++ b/verl/utils/reward_score/search_r1_like_qa_em.py
@@ -0,0 +1,156 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 Search-R1 Contributors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
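+
+# Exact-match style rewards for search-R1-like QA: the rollout is expected to
+# wrap its final answer in <answer>...</answer> tags; the last such span is
+# extracted and compared against ground_truth["target"] after normalization.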
+# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/verl/utils/reward_score/qa_em.py
+
+import random
+import re
+import string
+
+
+def normalize_answer(s):
+    """Lowercase, remove punctuation and articles, and normalize whitespace."""
+
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def em_check(prediction, golden_answers):
+    """Return 1 if the normalized prediction exactly matches any golden answer."""
+    if isinstance(golden_answers, str):
+        golden_answers = [golden_answers]
+    normalized_prediction = normalize_answer(prediction)
+    score = 0
+    for golden_answer in golden_answers:
+        golden_answer = normalize_answer(golden_answer)
+        if golden_answer == normalized_prediction:
+            score = 1
+            break
+    return score
+
+
+def subem_check(prediction, golden_answers):
+    """Return 1 if any normalized golden answer is a substring of the normalized prediction."""
+    if isinstance(golden_answers, str):
+        golden_answers = [golden_answers]
+    normalized_prediction = normalize_answer(prediction)
+    score = 0
+    for golden_answer in golden_answers:
+        golden_answer = normalize_answer(golden_answer)
+        if golden_answer in normalized_prediction:
+            score = 1
+            break
+    return score
+
+
+def extract_solution(solution_str):
+    """Extract the final answer from the last <answer>...</answer> span in the solution string."""
+    answer_pattern = r"<answer>(.*?)</answer>"
+    matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
+
+    # If there are no matches, return None
+    if len(matches) < 1:
+        return None
+
+    # If there are one or more matches, return the last one
+    return matches[-1].group(1).strip()
+
+
+def count_answer_tags(text):
+    """Count opening and closing <answer> tags in the text."""
+    opening_tags = text.count("<answer>")
+    closing_tags = text.count("</answer>")
+
+    return opening_tags, closing_tags
+
+
+def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
+    """The scoring function for exact match (EM).
+
+    Args:
+        solution_str: the solution text
+        ground_truth: the ground truth
+        method: the method to extract the solution, choices are 'strict' and 'flexible'
+        format_score: the score for the format
+        score: the score for the correct answer
+    """
+    answer = extract_solution(solution_str=solution_str)
+    open_count, close_count = count_answer_tags(solution_str)
+    do_print = random.randint(1, 64) == 1
+
+    if do_print:
+        print("--------------------------------")
+        print(f"Golden answers: {ground_truth['target']}")
+        if answer is not None:
+            print(f"Extracted answer is not None: {answer}")
+        else:
+            print("Extracted answer: None!")
+        print(f"Solution string: {solution_str}")
+
+    if answer is None:
+        return 0
+    else:
+        if em_check(answer, ground_truth["target"]):
+            if open_count > 10 or close_count > 10:  # prevent outputting a flood of <answer> tags
+                score = score / 4
+                return score
+            return score
+        else:
+            return format_score
+
+
+def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
+    """The scoring function for substring exact match (sub-EM).
+ + Args: + solution_str: the solution text + ground_truth: the ground truth + method: the method to extract the solution, choices are 'strict' and 'flexible' + format_score: the score for the format + score: the score for the correct answer + """ + answer = extract_solution(solution_str=solution_str) + do_print = random.randint(1, 64) == 1 + + if do_print: + print("--------------------------------") + print(f"Golden answers: {ground_truth['target']}") + print(f"Extracted answer: {answer}") + print(f"Solution string: {solution_str}") + + if answer is None: + return 0 + else: + if subem_check(answer, ground_truth["target"]): + return score + else: + return format_score
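+
+
+# Minimal usage sketch (hypothetical strings, for illustration only):
+#
+#   sol = "The search shows it was <answer>William Shakespeare</answer>."
+#   compute_score(sol, {"target": ["William Shakespeare"]})    # -> 1.0
+#   compute_score_subem(sol, {"target": ["Shakespeare"]})      # -> 1.0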