[sglang] Feat: Search Tool Invocation in Multi-Turn RL Training #1682
Changes from 22 commits
**New file** (+12 lines): reStructuredText documentation page.

```rst
=======================
Search Tool Integration
=======================

Introduction
------------

- We have added **search-tool** invocation capability to **verl-sglang MultiTurnRL**, enabling the model to issue retrieval requests during Actor rollout and directly use the returned results for training.

How to Use
----------

Refer to `verl-multiturn-searchR1-like.md <https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like.md>`_ or `verl-multiturn-searchR1-like_ZH.md <https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like_ZH.md>`_ in the `Awesome-ML-SYS-Tutorial repository <https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial>`_.
```
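To make the rollout loop concrete, here is an illustrative interaction in the tag protocol that the dataset prompt (see `DEFAULT_USER_CONTENT_PREFIX` in the preprocessing script below) instructs the model to follow; the question, search result, and answer are invented for illustration:

```python
# Invented example of one multi-turn rollout following the
# <think> / <tool_call> / <tool_response> / <answer> protocol described
# in DEFAULT_USER_CONTENT_PREFIX below.
example_rollout = (
    "<think>I need the host city of the 2008 Summer Olympics.</think>\n"
    "<tool_call>2008 Summer Olympics host city</tool_call>\n"
    "<tool_response>The 2008 Summer Olympics were held in Beijing, China.</tool_response>\n"
    "<think>The result answers the question directly.</think>\n"
    "<answer>Beijing</answer>"
)
print(example_rollout)
```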
**New file** (+168 lines): preprocessing script that downloads the Search-R1 parquet splits and converts them into verl's multi-turn tool format.

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import tempfile

import pandas as pd
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

from verl.utils.hdfs_io import copy, makedirs

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Configuration constants
DEFAULT_SYSTEM_CONTENT = "You are a helpful and harmless assistant."
DEFAULT_USER_CONTENT_PREFIX = (
    "Answer the given question. You must conduct reasoning inside <think> and </think> "
    "first every time you get new information. After reasoning, if you find you lack "
    "some knowledge, you can call a search engine by <tool_call> query </tool_call> "
    "and it will return the top searched results between <tool_response> and "
    "</tool_response>. You can search as many times as you want. If you find no "
    "further external knowledge needed, you can directly provide the answer inside "
    "<answer> and </answer>, without detailed illustrations. For example, "
    "<answer> Beijing </answer>. Question: "
)


def process_single_row(row, current_split_name, row_index):
    """
    Process a single row of data into the SearchR1-like format.

    Args:
        row: DataFrame row containing the original data
        current_split_name: Name of the current split (train/test)
        row_index: Index of the row in the DataFrame

    Returns:
        pd.Series: Processed row data in the required format
    """
    question = row.get("question", "")

    # Build prompt structure
    user_content = user_content_prefix.rstrip("\n") + question
    prompt = [{"role": "system", "content": system_content}, {"role": "user", "content": user_content}]

    # Extract ground truth from reward_model, falling back to golden_answers
    reward_model_data = row.get("reward_model")
    if isinstance(reward_model_data, dict) and "ground_truth" in reward_model_data:
        ground_truth = reward_model_data.get("ground_truth")
    else:
        ground_truth = row.get("golden_answers", [])

    # Tag the data source
    data_source_tagged = "searchR1_" + str(row.get("data_source", ""))

    # Build tools kwargs structure
    tools_kwargs = {"search": {"create_kwargs": {"ground_truth": ground_truth, "question": question, "data_source": data_source_tagged}}}

    # Build complete extra_info structure
    extra_info = {
        "index": row_index,
        "need_tools_kwargs": True,
        "question": question,
        "split": current_split_name,
        "tools_kwargs": tools_kwargs,
    }

    return pd.Series(
        {
            "data_source": data_source_tagged,
            "prompt": prompt,
            "ability": row.get("ability"),
            "reward_model": reward_model_data,
            "extra_info": extra_info,
            "metadata": row.get("metadata"),
        }
    )


def main():
    local_save_dir = os.path.expanduser(args.local_dir)
    os.makedirs(local_save_dir, exist_ok=True)

    processed_files = []

    # Download and process files using a temporary directory
    with tempfile.TemporaryDirectory() as tmp_download_dir:
        for split in ["train", "test"]:
            parquet_filename = f"{split}.parquet"
            logger.info(f"Processing {split} split...")

            try:
                # Download Parquet file from HuggingFace
                logger.info(f"Downloading {parquet_filename} from {args.hf_repo_id}")
                local_parquet_filepath = hf_hub_download(
                    repo_id=args.hf_repo_id,
                    filename=parquet_filename,
                    repo_type="dataset",
                    local_dir=tmp_download_dir,
                    local_dir_use_symlinks=False,
                )

                # Load and process Parquet file
                df_raw = pd.read_parquet(local_parquet_filepath)
                logger.info(f"Loaded {len(df_raw)} rows from {parquet_filename}")

                def apply_process_row(row, split_name=split):
                    return process_single_row(row, current_split_name=split_name, row_index=row.name)

                df_processed = df_raw.apply(apply_process_row, axis=1)

                # Save processed DataFrame
                output_file_path = os.path.join(local_save_dir, f"{split}.parquet")
                df_processed.to_parquet(output_file_path, index=False)
                logger.info(f"Saved {len(df_processed)} processed rows to {output_file_path}")
                processed_files.append(output_file_path)

            except EntryNotFoundError:
                logger.warning(f"{parquet_filename} not found in repository {args.hf_repo_id}")
            except Exception as e:
                logger.error(f"Error processing {split} split: {e}")

    if not processed_files:
        logger.warning("No data was processed or saved")
        return

    logger.info(f"Successfully processed {len(processed_files)} files to {local_save_dir}")

    # Copy to HDFS if specified
    if args.hdfs_dir:
        try:
            makedirs(args.hdfs_dir)
            copy(src=local_save_dir, dst=args.hdfs_dir)
            logger.info(f"Successfully copied files to HDFS: {args.hdfs_dir}")
        except Exception as e:
            logger.error(f"Error copying files to HDFS: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download Search-R1 from HuggingFace, process, and save to Parquet.")
    parser.add_argument("--hf_repo_id", default="PeterJinGo/nq_hotpotqa_train", help="HuggingFace dataset repository ID.")
    parser.add_argument("--local_dir", default="~/data/searchR1_processed_direct", help="Local directory to save the processed Parquet files.")
    parser.add_argument("--hdfs_dir", default=None, help="Optional HDFS directory to copy the Parquet files to.")

    args = parser.parse_args()

    # System and user content configuration (read as module globals by process_single_row)
    system_content = DEFAULT_SYSTEM_CONTENT
    user_content_prefix = DEFAULT_USER_CONTENT_PREFIX

    main()
```
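For orientation, here is a minimal sketch of one row emitted by `process_single_row` above. All values are invented; the `reward_model` payload is passed through from the source dataset, so its exact shape is an assumption:

```python
# Sketch of one processed row (all values invented).
processed_row = {
    "data_source": "searchR1_nq",  # original data_source tagged with "searchR1_"
    "prompt": [
        {"role": "system", "content": "You are a helpful and harmless assistant."},
        {"role": "user", "content": "Answer the given question. ... Question: who wrote Hamlet?"},
    ],
    "ability": None,  # passed through from the source row
    "reward_model": {"ground_truth": ["William Shakespeare"]},  # shape assumed
    "extra_info": {
        "index": 0,
        "need_tools_kwargs": True,  # signals the rollout to instantiate tools
        "question": "who wrote Hamlet?",
        "split": "train",
        "tools_kwargs": {
            "search": {
                "create_kwargs": {
                    "ground_truth": ["William Shakespeare"],
                    "question": "who wrote Hamlet?",
                    "data_source": "searchR1_nq",
                }
            }
        },
    },
    "metadata": None,
}
```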
**Collaborator:** Why is an empty file introduced here?

**Author (Contributor):** I didn't modify this file myself; it was likely introduced during the `git merge main`.
**New file** (+23 lines): Hydra training config that layers multi-turn sglang rollout settings on top of `ppo_trainer`.

```yaml
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer
  - _self_

data:
  max_prompt_length: 1024
  max_response_length: 1024
  train_batch_size: 256
  return_raw_chat: True
  shuffle: False

actor_rollout_ref:
  hybrid_engine: True
  rollout:
    name: sglang_async
    multi_turn:
      enable: True
      max_turns: 2
      format: qwen
```
**New file** (+22 lines): search tool configuration and function-calling schema.

```yaml
tools:
  - class_name: verl.tools.search_tool.SearchTool
    config:
      retrieval_service_url: http://127.0.0.1:8000/retrieve
      num_workers: 120
      rate_limit: 120
      timeout: 30
    tool_schema:
      type: function
      function:
        name: search
        description: Searches the web for relevant information based on the given query.
        parameters:
          type: object
          properties:
            query_list:
              type: array
              items:
                type: string
              description: A list of fully-formed semantic queries. The tool will return search results for each query.
          required:
            - query_list
```
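As an illustration of what this schema admits, a well-formed tool-call arguments object would look like the following (queries invented):

```python
# A tool call conforming to the "search" function schema above (values invented).
search_tool_call = {
    "name": "search",
    "arguments": {
        "query_list": [
            "2008 Summer Olympics host city",
            "when was the Beijing National Stadium built",
        ],
    },
}
```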
**New file** (+44 lines): download script for the wiki-18 E5 index and corpus. The `--repo_id` argument was previously parsed but then shadowed by a hardcoded value; it is now actually used for the index download.

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 Search-R1 Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/scripts/download.py

import argparse

from huggingface_hub import hf_hub_download

parser = argparse.ArgumentParser(description="Download files from a Hugging Face dataset repository.")
parser.add_argument("--repo_id", type=str, default="PeterJinGo/wiki-18-e5-index", help="Hugging Face repository ID for the index")
parser.add_argument("--save_path", type=str, required=True, help="Local directory to save files")

args = parser.parse_args()

# Download the two split parts of the FAISS index (e.g. the parts of "e5_Flat.index")
for file in ["part_aa", "part_ab"]:
    hf_hub_download(
        repo_id=args.repo_id,
        filename=file,
        repo_type="dataset",
        local_dir=args.save_path,
    )

# Download the wiki-18 corpus
repo_id = "PeterJinGo/wiki-18-corpus"
hf_hub_download(
    repo_id=repo_id,
    filename="wiki-18.jsonl.gz",
    repo_type="dataset",
    local_dir=args.save_path,
)
```
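After both parts are downloaded, they still need to be joined into a single index file. A minimal sketch, assuming the Search-R1 convention of concatenating `part_aa` and `part_ab` into `e5_Flat.index` (the target filename is taken from the script's original inline comment, not verified here):

```python
import os
import shutil

# Recombine the split index parts into one file, equivalent to:
#   cat part_aa part_ab > e5_Flat.index
save_path = os.path.expanduser("~/data/wiki18_index")  # hypothetical; same dir as --save_path
with open(os.path.join(save_path, "e5_Flat.index"), "wb") as out:
    for part in ["part_aa", "part_ab"]:
        with open(os.path.join(save_path, part), "rb") as src:
            shutil.copyfileobj(src, out)
```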
**Reviewer:** I have concerns about how to maintain the documentation in the long term, e.g. if the example script changes in the verl repo while the doc it references lives in a separate repo. The best way is to colocate the documentation with the code, in the same repo. For now I'm OK with merging the changes and revisiting such problems in the future.

**Author:** We'll keep the docs up to date and find a good way to manage them!