From 95b7cfec73a0fb7d37f666448c7c3937f5e03b98 Mon Sep 17 00:00:00 2001
From: Zilong Wang
Date: Tue, 29 Jul 2025 02:45:36 +0800
Subject: [PATCH 01/10] [feat] add rag agent example

---
 examples/rag_agent/README.md                  |  15 +
 examples/rag_agent/rag_agent.py               |  78 +++++
 examples/rag_agent/train.sh                   |  53 +++
 examples/rag_agent/utils.py                   | 329 ++++++++++++++++++
 .../wiki_retriever_install.sh                 |   3 +
 .../wiki_retriever_mcp/wiki_retriever_mcp.py  |  51 +++
 6 files changed, 529 insertions(+)
 create mode 100644 examples/rag_agent/README.md
 create mode 100644 examples/rag_agent/rag_agent.py
 create mode 100644 examples/rag_agent/train.sh
 create mode 100644 examples/rag_agent/utils.py
 create mode 100644 examples/rag_agent/wiki_retriever_mcp/wiki_retriever_install.sh
 create mode 100644 examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py

diff --git a/examples/rag_agent/README.md b/examples/rag_agent/README.md
new file mode 100644
index 00000000..35e79b1e
--- /dev/null
+++ b/examples/rag_agent/README.md
@@ -0,0 +1,15 @@
+# RAG Agent Example
+
+This example originally runs on a single node with four GPUs, each requiring at least 40GB of memory.
+
+1. Prepare the RAG dataset in the wiki_retriever_mcp folder. Wiki chunks (`nq_list.pkl`) and Faiss index (`nq_hnsw_faiss_n32e40.index`) are required. (Full wiki dump files are huge; additional information will be provided later.)
+2. Prepare the training data in the `data` folder. Download from [here](https://drive.google.com/drive/folders/1hEqOY4EbplUB5ew-8UPFhV_5QU2j7WCN?usp=drive_link). `musique_train.parquet` and `musique_dev_128.parquet` are required.
+3. Set up the environment for wiki retriever MCP: `bash wiki_retriever_install.sh`. This will install the required packages and set up the environment for the wiki retriever MCP.
+4. Start the wiki retriever MCP: `python wiki_retriever_mcp.py`. This will start the wiki retriever MCP server.
+5. Start Ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, you need to set the WANDB_API_KEY environment variable before starting Ray.
+6. Run the agent: `python calc_agent.py`. This automatically launches 12 agent workers by default.
+7. In another terminal, launch the training server: `bash train.sh`.
+
+## Evaluation
+
+Results are coming soon.
diff --git a/examples/rag_agent/rag_agent.py b/examples/rag_agent/rag_agent.py
new file mode 100644
index 00000000..ab462efc
--- /dev/null
+++ b/examples/rag_agent/rag_agent.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import os
+import re
+import shutil
+import sys
+import tempfile
+import time
+from typing import Any, Literal, Optional
+
+import dotenv
+import termcolor
+from agents import (Agent, Runner, function_tool, gen_trace_id,
+                    set_trace_processors, set_tracing_disabled, trace)
+from agents.extensions.models.litellm_model import LitellmModel
+from agents.mcp import MCPServer, MCPServerSse
+from agents.model_settings import ModelSettings
+from agents.tracing.processors import BatchTraceProcessor, ConsoleSpanExporter
+from utils import compute_scores
+
+import agentlightning
+from agentlightning import (LLM, LitAgent, NamedResources, Trainer,
+                            configure_logger, reward)
+
+configure_logger()
+
+agent_prompt="""You are an assistant who answers questions using Wikipedia retriever. Answer the question using only the retrieved passages. Verify your answer directly against the text.
+
+After each search:
+- Summarize findings.
+- Decide if info is sufficient.
+  - If sufficient: reply in <answer>...</answer> with your answer. The answer must be extremely concise: a single word or a few words only.
+  - If not: suggest the next search needed to fill info gaps. The system will return top 3 relevant Wikipedia chunks.
+- Explain your reasoning for the chosen action.
+
+Repeat as needed. When done, wrap your final, concise answer in <answer></answer> tags."""
+
+class RAGAgent(LitAgent):
+    def __init__(self):
+        self.mcp_server_url = "http://127.0.0.1:8099/sse"
+
+    async def training_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
+        llm: LLM = resources.get("main_llm")
+        print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
+        async with MCPServerSse(
+            name="wiki_retrieval_mcp",
+            params={"url": self.mcp_server_url},
+        ) as server:
+            agent = Agent(
+                model=LitellmModel(model='hosted_vllm/'+llm.model, base_url=llm.endpoint),
+                model_settings=ModelSettings(
+                    max_tokens=4096,
+                    temperature=0.7,
+                ),
+                name="Assistant",
+                instructions=agent_prompt,
+                mcp_servers=[server],
+            )
+            result = await Runner.run(agent, task["question"])
+            answer = result.final_output
+            reward = compute_scores(answer, str(task["answer"]))
+            print("question:{} answer: {} ground_truth: {} reward: {}".format(task["question"], answer, task["answer"], reward))
+
+
+    async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
+        llm: LLM = resources.get("main_llm")
+        resources = {
+            "main_llm": LLM(
+                endpoint=llm.endpoint,
+                model=llm.model,
+                sampling_parameters={"temperature": 0.7},
+            )
+        }
+        return await self.training_rollout_async(task, rollout_id, resources)
+
+
+if __name__ == "__main__":
+    Trainer(n_workers=12).fit(RAGAgent(), "http://localhost:9999/")
\ No newline at end of file
diff --git a/examples/rag_agent/train.sh b/examples/rag_agent/train.sh
new file mode 100644
index 00000000..7615b085
--- /dev/null
+++ b/examples/rag_agent/train.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+set -e
+
+export N_GPUS=1
+export BASE_MODEL=Qwen/Qwen3-1.7B
+export DATA_DIR=data
+export ROLLOUT_TP_SIZE=1
+export EXPERIMENT_NAME=rag_agent
+export PROJECT_NAME=AgentLightning
+
+echo "Starting training script..."
+
+python -m agentlightning.verl \
+    algorithm.adv_estimator=grpo \
+    data.train_files=${DATA_DIR}/musique_train.parquet \
+    data.val_files=${DATA_DIR}/musique_dev_128.parquet \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
+    trainer.n_gpus_per_node=${N_GPUS} \
+    data.train_batch_size=32 \
+    actor_rollout_ref.rollout.n=4 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.multi_turn.format=hermes \
+    actor_rollout_ref.model.path=${BASE_MODEL} \
+    data.max_prompt_length=4096 \
+    data.max_response_length=2048 \
+    data.truncation='error' \
+    trainer.val_before_train=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.actor.kl_loss_coef=0.000 \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.clip_ratio_low=0.2 \
+    actor_rollout_ref.actor.clip_ratio_high=0.3 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=True \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name=${PROJECT_NAME} \
+    trainer.experiment_name=${EXPERIMENT_NAME} \
+    trainer.nnodes=1 \
+    trainer.save_freq=40 \
+    trainer.test_freq=20 \
+    trainer.total_epochs=2 $@
diff --git a/examples/rag_agent/utils.py b/examples/rag_agent/utils.py
new file mode 100644
index 00000000..8ea7b267
--- /dev/null
+++ b/examples/rag_agent/utils.py
@@ -0,0 +1,329 @@
+#import ujson as json
+import json
+import pickle
+import re
+import string
+import sys
+from collections import Counter
+
+ANS_BEGIN = "<answer>"
+ANS_END = "</answer>"
+GEN_BEGIN = "<|im_start|>assistant\n"
+FORMAT_SCORE = 0.1
+FORMAT_PUNISH = -2
+
+def normalize_answer(s):
+
+    def remove_articles(text):
+        return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+    def white_space_fix(text):
+        return ' '.join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return ''.join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+    normalized_prediction = normalize_answer(prediction)
+    normalized_ground_truth = normalize_answer(ground_truth)
+
+    ZERO_METRIC = (0, 0, 0)
+
+    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+        return ZERO_METRIC
+    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+        return ZERO_METRIC
+
+    prediction_tokens = normalized_prediction.split()
+    ground_truth_tokens = normalized_ground_truth.split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return ZERO_METRIC
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1, precision, recall
+
+def lenient_f1_score(prediction, ground_truth):
+    normalized_prediction = normalize_answer(prediction)
+    normalized_ground_truth = normalize_answer(ground_truth)
+
+    ZERO_METRIC = (0, 0, 0)
+
+    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+        if normalized_ground_truth == 'yes' and ('no' in normalized_prediction or 'noanswer' in normalized_prediction):
+            return ZERO_METRIC
+        if normalized_ground_truth == 'no' and ('yes' in normalized_prediction or 'noanswer' in normalized_prediction):
+            return ZERO_METRIC
+
+    prediction_tokens = normalized_prediction.split()
+    ground_truth_tokens = normalized_ground_truth.split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return ZERO_METRIC
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1, precision, recall
+
+
+def exact_match_score(prediction, ground_truth):
+    return (normalize_answer(prediction) == normalize_answer(ground_truth))
+
+def cover_exact_match_score(prediction, ground_truth):
+    return (normalize_answer(ground_truth) in normalize_answer(prediction))
+
+def extract_answer(response):
+    if ANS_BEGIN not in response or ANS_END not in response:
+        return ''
+    pos1 = response.rfind(ANS_BEGIN)
+    pos2 = response.rfind(ANS_END)
+    assert pos2 !=-1
+    if pos1 != -1:
+        ans = response[pos1 + len(ANS_BEGIN):pos2]
+    else:
+        ans = response[len(ANS_BEGIN):pos2]
+    return ans
+
+def split_response(text):
+    start_response = text.rfind(GEN_BEGIN)
+    response = text[start_response + len(GEN_BEGIN):]
+    prompt = text[:-len(response)]
+    return prompt, response
+
+def extract_recall_chunk(prompt, response):
+    import re
+
+    # Regex matching the content after 1. and 2. inside each search_step
+    pattern = r"Retrieved sentences:\s*1\.\s*(.*?)\s*2\.\s*(.*?)(?:\n\s*\d+\.|\n\n|$)"
+
+    # Use re.findall to extract all (s1, s2) pairs
+    origin_recall = re.findall(pattern, prompt, re.DOTALL)
+    sequential_recall = re.findall(pattern, response, re.DOTALL)
+    origin_recall_set = set(s for pair in origin_recall for s in pair)
+    sequential_recall_set = set(s for pair in sequential_recall for s in pair)
+    
+    return origin_recall_set, sequential_recall_set
+
+import re
+
+
+def extract_retrieved_paragraphs(log_text):
+    # Regex matching the content after "Retrieved paragraph:"
+    pattern = re.compile(r"Retrieved paragraph:\s*(.*?)\n", re.DOTALL)
+    
+    # Extract the matched paragraphs
+    matches = pattern.findall(log_text)
+    matches = list(set(matches))
+    return matches
+
+def compute_score(prediction, gold, gold_sentences=None, data_source=None):
+    # format acc
+    format_acc = FORMAT_SCORE
+    
+    
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == '':
+        # format score 0.1
+        # if '<answer>' not in response or '</answer>' not in response:
+        #     return 0.0
+        # return 0.0
+        delimiter = "<|im_start|>assistant"
+        last_time_ans = response.split(delimiter)[-1]
+        if '</answer>' not in last_time_ans:
+            return 0.0
+        return format_acc
+
+    # answer acc
+    em, cem = exact_match_score(ans, gold), cover_exact_match_score(ans, gold)
+    f1, prec, recall = f1_score(ans, gold)
+    
+    
+    if fact_checking_api(prediction, ans):
+        answer_acc = max(float(em), f1)
+    else:
+        answer_acc = 0
+    # # search acc
+    # if gold_sentences and search_weight:
+    #     origin_recall_set, sequential_recall_set = extract_recall_chunk(prompt, response)
+    #     gold_sentences_set = set(gold_sentences) - origin_recall_set
+    #     matched = gold_sentences_set & sequential_recall_set
+    #     search_acc = len(matched) / len(gold_sentences_set) if len(gold_sentences_set) != 0 else 1.0
+    #     # print(f's_acc {search_acc}|a_acc {answer_acc=}| score {format_acc + (1 - format_acc) * (search_weight + (1 - search_weight) * answer_acc)} |m_len {len(matched)}|g_len {len(gold_sentences_set)}|o_len {len(origin_recall_set)}|s_len {len(sequential_recall_set)}|{gold_sentences_set}|{sequential_recall_set}')
+    #     if search_acc < 1:
+    #         return format_acc + (1 - format_acc) * search_weight * search_acc
+    # # print(f'SCORE: {score} | {ans} | {gold} | {prediction}' )
+
+    return format_acc + (1 - format_acc) * answer_acc
+    # return answer_acc
+
+def compute_reward(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    prediction = solution_str
+    gold = ground_truth
+    return compute_score(prediction, gold, gold_sentences=gold_sentences, data_source=data_source)
+
+def compute_em(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == '':
+        # format score 0.1
+        # if '<answer>' not in response or '</answer>' not in response:
+        #     return 0.0
+        return 0.0
+
+    # answer acc
+    em = exact_match_score(ans, gold)
+    return em
+
+def compute_cem(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == '':
+        return 0.0
+
+    # answer acc
+    cem = cover_exact_match_score(ans, gold)
+    return cem
+
+
+def compute_response_cem(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = response
+    if ans == '':
+        return 0.0
+
+    # answer acc
+    cem = cover_exact_match_score(ans, gold)
+    return cem
+
+def compute_lenient_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == '':
+        return 0.0
+
+    # answer acc
+    f1, prec, recall = lenient_f1_score(ans, gold)
+    return f1
+
+def compute_lenient_response_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = response
+    if ans == '':
+        return 0.0
+
+    # answer acc
+    f1, prec, recall = lenient_f1_score(ans, gold)
+    return f1
+
+
+
+def compute_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == '':
+        return 0.0
+
+    # answer acc
+    f1, prec, recall = f1_score(ans, gold)
+    return f1
+
+def compute_format(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == '':
+        delimiter = "<|im_start|>assistant"
+        last_time_ans = response.split(delimiter)[-1]
+        if '</answer>' not in last_time_ans:
+            return 0
+    return FORMAT_SCORE
+
+def split_trace(text):
+    start_response = text.find(GEN_BEGIN)
+    response =
text[start_response + len(GEN_BEGIN):] + prompt = text[:-len(response)] + return prompt, response + +def compute_action_query(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): + prediction = solution_str + gold = ground_truth + prompt, trace = split_trace(prediction) + res = min(trace.count('') + trace.count('')) + return res + +def compute_action_bm25(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): + prediction = solution_str + gold = ground_truth + prompt, trace = split_trace(prediction) + res = min(trace.count('')) + return res + +def compute_action_read_pre(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): + prediction = solution_str + gold = ground_truth + prompt, trace = split_trace(prediction) + res = min(trace.count('')) + return res + +def compute_action_read_nxt(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): + prediction = solution_str + gold = ground_truth + prompt, trace = split_trace(prediction) + res = min(trace.count('')) + return res + +def compute_action_continue(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): + prediction = solution_str + gold = ground_truth + prompt, trace = split_trace(prediction) + res = min(trace.count(', continue'), trace.count('')) + return res + +def compute_action_match(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): + prediction = solution_str + gold = ground_truth + prompt, trace = split_trace(prediction) + res = min(trace.count(', match_phrase="'), trace.count('')) + return res + +def compute_total_action_number(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): + prediction = solution_str + gold = ground_truth + prompt, trace = split_trace(prediction) + res = min(trace.count('')) + return res + +# define reward functions for evaluation + +def compute_scores(answer, ground_truth): + parsed_answer = extract_answer(answer) + if parsed_answer is None: + return -0.1 + f1, precision, recall = f1_score(parsed_answer, ground_truth) + # em = float(exact_match_score(parsed_answer, ground_truth)) + # cem = float(cover_exact_match_score(answer, ground_truth)) + return f1 \ No newline at end of file diff --git a/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_install.sh b/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_install.sh new file mode 100644 index 00000000..7b22cdfc --- /dev/null +++ b/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_install.sh @@ -0,0 +1,3 @@ +conda create -n mcp_server python=3.12 -y +conda activate mcp_server +pip install faiss-cpu==1.11.0 fastmcp==2.5.1 sentence-transformers==4.1.0 \ No newline at end of file diff --git a/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py b/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py new file mode 100644 index 00000000..12a94840 --- /dev/null +++ b/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py @@ -0,0 +1,51 @@ +import faiss +from sentence_transformers import SentenceTransformer +import pickle +from fastmcp import FastMCP + +#index = faiss.read_index("/mnt/input/agent_lightning/nq_hnsw_faiss_n32e40.index") +index = faiss.read_index("nq_hnsw_faiss_n32e40.index") +print("Index loaded successfully.") + +model = SentenceTransformer("BAAI/bge-large-en-v1.5") +print("Model loaded successfully.") + +#with 
open('/mnt/input/agent_lightning/nq_list.pkl', 'rb') as f: +with open('nq_list.pkl', 'rb') as f: + chunks = pickle.load(f) +print("Chunks loaded successfully.") + +mcp = FastMCP(name="wiki retrievel mcp") + +@mcp.tool( + name="retrieve", + description="retrieve relevant chunks from the wikipedia", +) +def retrieve(query: str) -> list: + """ + Retrieve relevant chunks from the Wikipedia dataset. + + Args: + query (str): The query string to search for. + + Returns: + list: A list of dictionaries containing the retrieved chunks and their metadata. + """ + top_k = 4 # Number of top results to return + embedding = model.encode([query], normalize_embeddings=True) + D, I = index.search(embedding, top_k) + + results = [] + for i in range(top_k): + if I[0][i] != -1: + chunk = chunks[I[0][i]] + results.append({ + "chunk": chunk, + "chunk_id": I[0][i], + "distance": D[0][i] + }) + return results + + +if __name__ == "__main__": + mcp.run(transport="sse",host="127.0.0.1", port=8099) From 70f1547ea3ed479970e9ff82425096031042094e Mon Sep 17 00:00:00 2001 From: Wang Zilong Date: Tue, 29 Jul 2025 02:47:27 +0800 Subject: [PATCH 02/10] Update examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py b/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py index 12a94840..e424f2e6 100644 --- a/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py +++ b/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py @@ -15,7 +15,7 @@ chunks = pickle.load(f) print("Chunks loaded successfully.") -mcp = FastMCP(name="wiki retrievel mcp") +mcp = FastMCP(name="wiki retrieval mcp") @mcp.tool( name="retrieve", From 4b1ba9a2c498721b7732cefd7d125dffe799f7c6 Mon Sep 17 00:00:00 2001 From: Wang Zilong Date: Tue, 29 Jul 2025 02:47:40 +0800 Subject: [PATCH 03/10] Update examples/rag_agent/rag_agent.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/rag_agent/rag_agent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/rag_agent/rag_agent.py b/examples/rag_agent/rag_agent.py index ab462efc..b3cd6829 100644 --- a/examples/rag_agent/rag_agent.py +++ b/examples/rag_agent/rag_agent.py @@ -60,8 +60,7 @@ async def training_rollout_async(self, task: Any, rollout_id: str, resources: Na answer = result.final_output reward = compute_scores(answer, str(task["answer"])) print("question:{} answer: {} ground_truth: {} reward: {}".format(task["question"], answer, task["answer"], reward)) - - + return reward async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any: llm: LLM = resources.get("main_llm") resources = { From fd88f6ec7e4a4fc88f2211c0588ca6a6f363b756 Mon Sep 17 00:00:00 2001 From: Wang Zilong Date: Tue, 29 Jul 2025 02:48:06 +0800 Subject: [PATCH 04/10] Update examples/rag_agent/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/rag_agent/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rag_agent/README.md b/examples/rag_agent/README.md index 35e79b1e..19119635 100644 --- a/examples/rag_agent/README.md +++ b/examples/rag_agent/README.md @@ -7,7 +7,7 @@ This example originally runs on a single node with four GPUs, each requiring at 3. 
Set up the environment for wiki retriever MCP: `bash wiki_retriever_install.sh`. This will install the required packages and set up the environment for the wiki retriever MCP. 4. Start the wiki retriever MCP: `python wiki_retriever_mcp.py`. This will start the wiki retriever MCP server. 5. Start Ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, you need to set the WANDB_API_KEY environment variable before starting Ray. -6. Run the agent: `python calc_agent.py`. This automatically launches 12 agent workers by default. +6. Run the agent: `python rag_agent.py`. This automatically launches 12 agent workers by default. 7. In another terminal, launch the training server: `bash train.sh`. ## Evaluation From f7bf8a97e8b3ee8bd5c3fe1f17b96609a81837dc Mon Sep 17 00:00:00 2001 From: Wang Zilong Date: Tue, 29 Jul 2025 02:48:21 +0800 Subject: [PATCH 05/10] Update examples/rag_agent/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/rag_agent/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/rag_agent/utils.py b/examples/rag_agent/utils.py index 8ea7b267..8d5421d9 100644 --- a/examples/rag_agent/utils.py +++ b/examples/rag_agent/utils.py @@ -1,4 +1,3 @@ -#import ujson as json import json import pickle import re From b51a7449188fb5bf853038dc1bb6dfcd0144c957 Mon Sep 17 00:00:00 2001 From: Zilong Wang Date: Tue, 29 Jul 2025 02:51:46 +0800 Subject: [PATCH 06/10] [fix] rag agent --- examples/rag_agent/README.md | 2 +- examples/rag_agent/rag_agent.py | 3 ++- examples/rag_agent/utils.py | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/rag_agent/README.md b/examples/rag_agent/README.md index 35e79b1e..19119635 100644 --- a/examples/rag_agent/README.md +++ b/examples/rag_agent/README.md @@ -7,7 +7,7 @@ This example originally runs on a single node with four GPUs, each requiring at 3. Set up the environment for wiki retriever MCP: `bash wiki_retriever_install.sh`. This will install the required packages and set up the environment for the wiki retriever MCP. 4. Start the wiki retriever MCP: `python wiki_retriever_mcp.py`. This will start the wiki retriever MCP server. 5. Start Ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, you need to set the WANDB_API_KEY environment variable before starting Ray. -6. Run the agent: `python calc_agent.py`. This automatically launches 12 agent workers by default. +6. Run the agent: `python rag_agent.py`. This automatically launches 12 agent workers by default. 7. In another terminal, launch the training server: `bash train.sh`. 
 ## Evaluation
diff --git a/examples/rag_agent/rag_agent.py b/examples/rag_agent/rag_agent.py
index ab462efc..d23377a8 100644
--- a/examples/rag_agent/rag_agent.py
+++ b/examples/rag_agent/rag_agent.py
@@ -43,7 +43,7 @@ async def training_rollout_async(self, task: Any, rollout_id: str, resources: Na
         llm: LLM = resources.get("main_llm")
         print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
         async with MCPServerSse(
-            name="wiki_retrieval_mcp",
+            name="wiki_retriever_mcp",
             params={"url": self.mcp_server_url},
         ) as server:
             agent = Agent(
@@ -60,6 +60,7 @@ async def training_rollout_async(self, task: Any, rollout_id: str, resources: Na
             answer = result.final_output
             reward = compute_scores(answer, str(task["answer"]))
             print("question:{} answer: {} ground_truth: {} reward: {}".format(task["question"], answer, task["answer"], reward))
+            return reward
 
 
     async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
diff --git a/examples/rag_agent/utils.py b/examples/rag_agent/utils.py
index 8ea7b267..f6860d30 100644
--- a/examples/rag_agent/utils.py
+++ b/examples/rag_agent/utils.py
@@ -236,7 +236,8 @@ def compute_lenient_response_f1(solution_str=None, ground_truth=None, gold_sente
     f1, prec, recall = lenient_f1_score(ans, gold)
     return f1
 
-
+def fact_checking_api(prediction, ans):
+    return True  # Placeholder for actual fact-checking logic
 
 def compute_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
From bc30d2652904ff03079b68215ca8d395e01778b3 Mon Sep 17 00:00:00 2001
From: Zilong Wang
Date: Tue, 29 Jul 2025 02:55:47 +0800
Subject: [PATCH 07/10] [fmt] black formatter

---
 examples/rag_agent/rag_agent.py               |  16 ++-
 examples/rag_agent/utils.py                   | 116 +++++++++++-------
 .../wiki_retriever_mcp/wiki_retriever_mcp.py  |  23 ++--
 3 files changed, 87 insertions(+), 68 deletions(-)

diff --git a/examples/rag_agent/rag_agent.py b/examples/rag_agent/rag_agent.py
index d23377a8..979bbeba 100644
--- a/examples/rag_agent/rag_agent.py
+++ b/examples/rag_agent/rag_agent.py
@@ -10,8 +10,7 @@
 
 import dotenv
 import termcolor
-from agents import (Agent, Runner, function_tool, gen_trace_id,
-                    set_trace_processors, set_tracing_disabled, trace)
+from agents import Agent, Runner, function_tool, gen_trace_id, set_trace_processors, set_tracing_disabled, trace
 from agents.extensions.models.litellm_model import LitellmModel
 from agents.mcp import MCPServer, MCPServerSse
 from agents.model_settings import ModelSettings
@@ -19,12 +18,11 @@
 from utils import compute_scores
 
 import agentlightning
-from agentlightning import (LLM, LitAgent, NamedResources, Trainer,
-                            configure_logger, reward)
+from agentlightning import LLM, LitAgent, NamedResources, Trainer, configure_logger, reward
 
 configure_logger()
 
-agent_prompt="""You are an assistant who answers questions using Wikipedia retriever. Answer the question using only the retrieved passages. Verify your answer directly against the text.
+agent_prompt = """You are an assistant who answers questions using Wikipedia retriever. Answer the question using only the retrieved passages. Verify your answer directly against the text.
 
 After each search:
 - Summarize findings.
@@ -34,7 +32,8 @@
 - Explain your reasoning for the chosen action.
 
 Repeat as needed. When done, wrap your final, concise answer in <answer></answer> tags."""
-
+
+
 class RAGAgent(LitAgent):
     def __init__(self):
         self.mcp_server_url = "http://127.0.0.1:8099/sse"
@@ -47,7 +46,7 @@ async def training_rollout_async(self, task: Any, rollout_id: str, resources: Na
             params={"url": self.mcp_server_url},
         ) as server:
             agent = Agent(
-                model=LitellmModel(model='hosted_vllm/'+llm.model, base_url=llm.endpoint),
+                model=LitellmModel(model="hosted_vllm/" + llm.model, base_url=llm.endpoint),
                 model_settings=ModelSettings(
                     max_tokens=4096,
                     temperature=0.7,
@@ -61,8 +60,7 @@ async def training_rollout_async(self, task: Any, rollout_id: str, resources: Na
             reward = compute_scores(answer, str(task["answer"]))
             print("question:{} answer: {} ground_truth: {} reward: {}".format(task["question"], answer, task["answer"], reward))
             return reward
-
 
     async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
         llm: LLM = resources.get("main_llm")
         resources = {
@@ -76,4 +74,4 @@ async def validation_rollout_async(self, task: Any, rollout_id: str, resources:
 
 
 if __name__ == "__main__":
-    Trainer(n_workers=12).fit(RAGAgent(), "http://localhost:9999/")
\ No newline at end of file
+    Trainer(n_workers=12).fit(RAGAgent(), "http://localhost:9999/")
diff --git a/examples/rag_agent/utils.py b/examples/rag_agent/utils.py
index 4cdc67ce..2210e1ec 100644
--- a/examples/rag_agent/utils.py
+++ b/examples/rag_agent/utils.py
@@ -11,17 +11,18 @@
 FORMAT_SCORE = 0.1
 FORMAT_PUNISH = -2
 
+
 def normalize_answer(s):
 
     def remove_articles(text):
-        return re.sub(r'\b(a|an|the)\b', ' ', text)
+        return re.sub(r"\b(a|an|the)\b", " ", text)
 
     def white_space_fix(text):
-        return ' '.join(text.split())
+        return " ".join(text.split())
 
     def remove_punc(text):
         exclude = set(string.punctuation)
-        return ''.join(ch for ch in text if ch not in exclude)
+        return "".join(ch for ch in text if ch not in exclude)
 
     def lower(text):
         return text.lower()
@@ -35,9 +36,9 @@ def f1_score(prediction, ground_truth):
 
     ZERO_METRIC = (0, 0, 0)
 
-    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+    if normalized_prediction in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
         return ZERO_METRIC
-    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+    if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
         return ZERO_METRIC
 
     prediction_tokens = normalized_prediction.split()
@@ -51,16 +52,17 @@ def f1_score(prediction, ground_truth):
     f1 = (2 * precision * recall) / (precision + recall)
     return f1, precision, recall
 
+
 def lenient_f1_score(prediction, ground_truth):
     normalized_prediction = normalize_answer(prediction)
     normalized_ground_truth = normalize_answer(ground_truth)
 
     ZERO_METRIC = (0, 0, 0)
 
-    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
-        if normalized_ground_truth == 'yes' and ('no' in normalized_prediction or 'noanswer' in normalized_prediction):
+    if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
+        if normalized_ground_truth == "yes" and ("no" in normalized_prediction or "noanswer" in normalized_prediction):
             return ZERO_METRIC
-        if normalized_ground_truth == 'no' and ('yes' in normalized_prediction or 'noanswer' in normalized_prediction):
+        if normalized_ground_truth == "no" and ("yes" in normalized_prediction or "noanswer" in normalized_prediction):
             return ZERO_METRIC
 
     prediction_tokens = normalized_prediction.split()
@@ -76,29 +78,33 @@ def lenient_f1_score(prediction, ground_truth):
 
 
 def exact_match_score(prediction, ground_truth):
-    return (normalize_answer(prediction) == normalize_answer(ground_truth))
+    return normalize_answer(prediction) == normalize_answer(ground_truth)
 
+
 def cover_exact_match_score(prediction, ground_truth):
-    return (normalize_answer(ground_truth) in normalize_answer(prediction))
+    return normalize_answer(ground_truth) in normalize_answer(prediction)
 
+
 def extract_answer(response):
     if ANS_BEGIN not in response or ANS_END not in response:
-        return ''
+        return ""
     pos1 = response.rfind(ANS_BEGIN)
     pos2 = response.rfind(ANS_END)
-    assert pos2 !=-1
+    assert pos2 != -1
     if pos1 != -1:
-        ans = response[pos1 + len(ANS_BEGIN):pos2]
+        ans = response[pos1 + len(ANS_BEGIN) : pos2]
     else:
-        ans = response[len(ANS_BEGIN):pos2]
+        ans = response[len(ANS_BEGIN) : pos2]
     return ans
 
+
 def split_response(text):
     start_response = text.rfind(GEN_BEGIN)
-    response = text[start_response + len(GEN_BEGIN):]
-    prompt = text[:-len(response)]
+    response = text[start_response + len(GEN_BEGIN) :]
+    prompt = text[: -len(response)]
     return prompt, response
 
+
 def extract_recall_chunk(prompt, response):
     import re
 
@@ -110,44 +116,44 @@ def extract_recall_chunk(prompt, response):
     sequential_recall = re.findall(pattern, response, re.DOTALL)
     origin_recall_set = set(s for pair in origin_recall for s in pair)
     sequential_recall_set = set(s for pair in sequential_recall for s in pair)
-    
+
     return origin_recall_set, sequential_recall_set
 
+
 import re
 
 
 def extract_retrieved_paragraphs(log_text):
     # Regex matching the content after "Retrieved paragraph:"
     pattern = re.compile(r"Retrieved paragraph:\s*(.*?)\n", re.DOTALL)
-    
+
     # Extract the matched paragraphs
     matches = pattern.findall(log_text)
     matches = list(set(matches))
     return matches
 
+
 def compute_score(prediction, gold, gold_sentences=None, data_source=None):
     # format acc
     format_acc = FORMAT_SCORE
-    
-    
+
     prompt, response = split_response(prediction)
     ans = extract_answer(response)
-    if ans == '':
+    if ans == "":
         # format score 0.1
         # if '<answer>' not in response or '</answer>' not in response:
         #     return 0.0
         # return 0.0
         delimiter = "<|im_start|>assistant"
         last_time_ans = response.split(delimiter)[-1]
-        if '</answer>' not in last_time_ans:
+        if "</answer>" not in last_time_ans:
             return 0.0
         return format_acc
 
     # answer acc
     em, cem = exact_match_score(ans, gold), cover_exact_match_score(ans, gold)
     f1, prec, recall = f1_score(ans, gold)
-    
-    
+
     if fact_checking_api(prediction, ans):
         answer_acc = max(float(em), f1)
     else:
@@ -162,40 +168,43 @@ def compute_score(prediction, gold, gold_sentences=None, data_source=None):
     #     if search_acc < 1:
     #         return format_acc + (1 - format_acc) * search_weight * search_acc
     # # print(f'SCORE: {score} | {ans} | {gold} | {prediction}' )
-    
+
     return format_acc + (1 - format_acc) * answer_acc
     # return answer_acc
 
+
 def compute_reward(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     return compute_score(prediction, gold, gold_sentences=gold_sentences, data_source=data_source)
 
+
 def compute_em(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     prompt, response = split_response(prediction)
     ans = extract_answer(response)
-    if ans == '':
+    if ans == "":
         # format score 0.1
         # if '<answer>' not in response or '</answer>' not in response:
         #     return 0.0
         return 0.0
 
     # answer acc
-    em = exact_match_score(ans, gold) 
+    em = exact_match_score(ans, gold)
     return em
 
+
 def compute_cem(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     prompt, response = split_response(prediction)
     ans = extract_answer(response)
-    if ans == '':
+    if ans == "":
         return 0.0
 
     # answer acc
-    cem = cover_exact_match_score(ans, gold) 
+    cem = cover_exact_match_score(ans, gold)
     return cem
 
 
@@ -204,127 +213,142 @@ def compute_response_cem(solution_str=None, ground_truth=None, gold_sentences=No
     prediction = solution_str
     gold = ground_truth
     prompt, response = split_response(prediction)
     ans = response
-    if ans == '':
+    if ans == "":
         return 0.0
 
     # answer acc
     cem = cover_exact_match_score(ans, gold)
     return cem
 
+
 def compute_lenient_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     prompt, response = split_response(prediction)
     ans = extract_answer(response)
-    if ans == '':
+    if ans == "":
         return 0.0
 
     # answer acc
     f1, prec, recall = lenient_f1_score(ans, gold)
     return f1
 
+
 def compute_lenient_response_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     prompt, response = split_response(prediction)
     ans = response
-    if ans == '':
+    if ans == "":
         return 0.0
 
     # answer acc
     f1, prec, recall = lenient_f1_score(ans, gold)
     return f1
 
+
 def fact_checking_api(prediction, ans):
     return True  # Placeholder for actual fact-checking logic
 
+
 def compute_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     prompt, response = split_response(prediction)
     ans = extract_answer(response)
-    if ans == '':
+    if ans == "":
         return 0.0
 
     # answer acc
     f1, prec, recall = f1_score(ans, gold)
     return f1
 
+
 def compute_format(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     prompt, response = split_response(prediction)
     ans = extract_answer(response)
-    if ans == '':
+    if ans == "":
         delimiter = "<|im_start|>assistant"
         last_time_ans = response.split(delimiter)[-1]
-        if '</answer>' not in last_time_ans:
+        if "</answer>" not in last_time_ans:
             return 0
     return FORMAT_SCORE
 
+
 def split_trace(text):
     start_response = text.find(GEN_BEGIN)
-    response = text[start_response + len(GEN_BEGIN):]
-    prompt = text[:-len(response)]
+    response = text[start_response + len(GEN_BEGIN) :]
+    prompt = text[: -len(response)]
     return prompt, response
 
+
 def compute_action_query(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     prompt, trace = split_trace(prediction)
-    res = min(trace.count('') + trace.count(''))
+    res = min(trace.count("") + trace.count(""))
     return res
 
+
 def compute_action_bm25(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     prompt, trace = split_trace(prediction)
-    res = min(trace.count(''))
+    res = min(trace.count(""))
     return res
 
+
 def compute_action_read_pre(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
     prediction = solution_str
     gold = ground_truth
     prompt, trace = split_trace(prediction)
-    res = min(trace.count(''))
+    res = min(trace.count(""))
     return res
 
+
 def compute_action_read_nxt(solution_str=None, ground_truth=None,
gold_sentences=None, data_source=None, extra_info=None): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) - res = min(trace.count('')) + res = min(trace.count("")) return res + def compute_action_continue(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) - res = min(trace.count(', continue'), trace.count('')) + res = min(trace.count(", continue"), trace.count("")) return res + def compute_action_match(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) - res = min(trace.count(', match_phrase="'), trace.count('')) + res = min(trace.count(', match_phrase="'), trace.count("")) return res + def compute_total_action_number(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) - res = min(trace.count('')) + res = min(trace.count("")) return res + # define reward functions for evaluation + def compute_scores(answer, ground_truth): parsed_answer = extract_answer(answer) if parsed_answer is None: @@ -326,4 +350,4 @@ def compute_scores(answer, ground_truth): f1, precision, recall = f1_score(parsed_answer, ground_truth) # em = float(exact_match_score(parsed_answer, ground_truth)) # cem = float(cover_exact_match_score(answer, ground_truth)) - return f1 \ No newline at end of file + return f1 diff --git a/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py b/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py index e424f2e6..5affb65c 100644 --- a/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py +++ b/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py @@ -3,20 +3,21 @@ import pickle from fastmcp import FastMCP -#index = faiss.read_index("/mnt/input/agent_lightning/nq_hnsw_faiss_n32e40.index") +# index = faiss.read_index("/mnt/input/agent_lightning/nq_hnsw_faiss_n32e40.index") index = faiss.read_index("nq_hnsw_faiss_n32e40.index") print("Index loaded successfully.") model = SentenceTransformer("BAAI/bge-large-en-v1.5") print("Model loaded successfully.") -#with open('/mnt/input/agent_lightning/nq_list.pkl', 'rb') as f: -with open('nq_list.pkl', 'rb') as f: +# with open('/mnt/input/agent_lightning/nq_list.pkl', 'rb') as f: +with open("nq_list.pkl", "rb") as f: chunks = pickle.load(f) print("Chunks loaded successfully.") - + mcp = FastMCP(name="wiki retrieval mcp") + @mcp.tool( name="retrieve", description="retrieve relevant chunks from the wikipedia", @@ -24,28 +25,24 @@ def retrieve(query: str) -> list: """ Retrieve relevant chunks from the Wikipedia dataset. - + Args: query (str): The query string to search for. - + Returns: list: A list of dictionaries containing the retrieved chunks and their metadata. 
""" top_k = 4 # Number of top results to return embedding = model.encode([query], normalize_embeddings=True) D, I = index.search(embedding, top_k) - + results = [] for i in range(top_k): if I[0][i] != -1: chunk = chunks[I[0][i]] - results.append({ - "chunk": chunk, - "chunk_id": I[0][i], - "distance": D[0][i] - }) + results.append({"chunk": chunk, "chunk_id": I[0][i], "distance": D[0][i]}) return results if __name__ == "__main__": - mcp.run(transport="sse",host="127.0.0.1", port=8099) + mcp.run(transport="sse", host="127.0.0.1", port=8099) From 5880071e96633491be14db8c4053914e5758bea5 Mon Sep 17 00:00:00 2001 From: Zilong Wang Date: Mon, 11 Aug 2025 18:12:17 +0800 Subject: [PATCH 08/10] [fix] lint with black --- examples/rag_agent/rag_agent.py | 37 ++++++-- examples/rag_agent/utils.py | 152 +++++++++++++++++++++++++++----- 2 files changed, 159 insertions(+), 30 deletions(-) diff --git a/examples/rag_agent/rag_agent.py b/examples/rag_agent/rag_agent.py index 979bbeba..bb3243f5 100644 --- a/examples/rag_agent/rag_agent.py +++ b/examples/rag_agent/rag_agent.py @@ -10,7 +10,15 @@ import dotenv import termcolor -from agents import Agent, Runner, function_tool, gen_trace_id, set_trace_processors, set_tracing_disabled, trace +from agents import ( + Agent, + Runner, + function_tool, + gen_trace_id, + set_trace_processors, + set_tracing_disabled, + trace, +) from agents.extensions.models.litellm_model import LitellmModel from agents.mcp import MCPServer, MCPServerSse from agents.model_settings import ModelSettings @@ -18,7 +26,14 @@ from utils import compute_scores import agentlightning -from agentlightning import LLM, LitAgent, NamedResources, Trainer, configure_logger, reward +from agentlightning import ( + LLM, + LitAgent, + NamedResources, + Trainer, + configure_logger, + reward, +) configure_logger() @@ -38,7 +53,9 @@ class RAGAgent(LitAgent): def __init__(self): self.mcp_server_url = "http://127.0.0.1:8099/sse" - async def training_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any: + async def training_rollout_async( + self, task: Any, rollout_id: str, resources: NamedResources + ) -> Any: llm: LLM = resources.get("main_llm") print("Training with model:", llm.model, "on endpoint:", llm.endpoint) async with MCPServerSse( @@ -46,7 +63,9 @@ async def training_rollout_async(self, task: Any, rollout_id: str, resources: Na params={"url": self.mcp_server_url}, ) as server: agent = Agent( - model=LitellmModel(model="hosted_vllm/" + llm.model, base_url=llm.endpoint), + model=LitellmModel( + model="hosted_vllm/" + llm.model, base_url=llm.endpoint + ), model_settings=ModelSettings( max_tokens=4096, temperature=0.7, @@ -58,10 +77,16 @@ async def training_rollout_async(self, task: Any, rollout_id: str, resources: Na result = await Runner.run(agent, task["question"]) answer = result.final_output reward = compute_scores(answer, str(task["answer"])) - print("question:{} answer: {} ground_truth: {} reward: {}".format(task["question"], answer, task["answer"], reward)) + print( + "question:{} answer: {} ground_truth: {} reward: {}".format( + task["question"], answer, task["answer"], reward + ) + ) return reward - async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any: + async def validation_rollout_async( + self, task: Any, rollout_id: str, resources: NamedResources + ) -> Any: llm: LLM = resources.get("main_llm") resources = { "main_llm": LLM( diff --git a/examples/rag_agent/utils.py b/examples/rag_agent/utils.py index 
2210e1ec..87e8462b 100644 --- a/examples/rag_agent/utils.py +++ b/examples/rag_agent/utils.py @@ -13,7 +13,6 @@ def normalize_answer(s): - def remove_articles(text): return re.sub(r"\b(a|an|the)\b", " ", text) @@ -36,9 +35,15 @@ def f1_score(prediction, ground_truth): ZERO_METRIC = (0, 0, 0) - if normalized_prediction in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: + if ( + normalized_prediction in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): return ZERO_METRIC - if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: + if ( + normalized_ground_truth in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): return ZERO_METRIC prediction_tokens = normalized_prediction.split() @@ -59,10 +64,17 @@ def lenient_f1_score(prediction, ground_truth): ZERO_METRIC = (0, 0, 0) - if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: - if normalized_ground_truth == "yes" and ("no" in normalized_prediction or "noanswer" in normalized_prediction): + if ( + normalized_ground_truth in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): + if normalized_ground_truth == "yes" and ( + "no" in normalized_prediction or "noanswer" in normalized_prediction + ): return ZERO_METRIC - if normalized_ground_truth == "no" and ("yes" in normalized_prediction or "noanswer" in normalized_prediction): + if normalized_ground_truth == "no" and ( + "yes" in normalized_prediction or "noanswer" in normalized_prediction + ): return ZERO_METRIC prediction_tokens = normalized_prediction.split() @@ -173,13 +185,27 @@ def compute_score(prediction, gold, gold_sentences=None, data_source=None): # return answer_acc -def compute_reward(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_reward( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth - return compute_score(prediction, gold, gold_sentences=gold_sentences, data_source=data_source) - - -def compute_em(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): + return compute_score( + prediction, gold, gold_sentences=gold_sentences, data_source=data_source + ) + + +def compute_em( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, response = split_response(prediction) @@ -195,7 +221,13 @@ def compute_em(solution_str=None, ground_truth=None, gold_sentences=None, data_s return em -def compute_cem(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_cem( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, response = split_response(prediction) @@ -208,7 +240,13 @@ def compute_cem(solution_str=None, ground_truth=None, gold_sentences=None, data_ return cem -def compute_response_cem(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_response_cem( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, response = 
split_response(prediction) @@ -221,7 +259,13 @@ def compute_response_cem(solution_str=None, ground_truth=None, gold_sentences=No return cem -def compute_lenient_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_lenient_f1( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, response = split_response(prediction) @@ -234,7 +278,13 @@ def compute_lenient_f1(solution_str=None, ground_truth=None, gold_sentences=None return f1 -def compute_lenient_response_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_lenient_response_f1( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, response = split_response(prediction) @@ -251,7 +301,13 @@ def fact_checking_api(prediction, ans): return True # Placeholder for actual fact-checking logic -def compute_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_f1( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, response = split_response(prediction) @@ -264,7 +320,13 @@ def compute_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_s return f1 -def compute_format(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_format( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, response = split_response(prediction) @@ -284,7 +346,13 @@ def split_trace(text): return prompt, response -def compute_action_query(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_action_query( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) @@ -292,7 +360,13 @@ def compute_action_query(solution_str=None, ground_truth=None, gold_sentences=No return res -def compute_action_bm25(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_action_bm25( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) @@ -300,7 +374,13 @@ def compute_action_bm25(solution_str=None, ground_truth=None, gold_sentences=Non return res -def compute_action_read_pre(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_action_read_pre( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) @@ -308,7 +388,13 @@ def compute_action_read_pre(solution_str=None, ground_truth=None, gold_sentences return res -def compute_action_read_nxt(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_action_read_nxt( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): 
prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) @@ -316,7 +402,13 @@ def compute_action_read_nxt(solution_str=None, ground_truth=None, gold_sentences return res -def compute_action_continue(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_action_continue( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) @@ -324,7 +416,13 @@ def compute_action_continue(solution_str=None, ground_truth=None, gold_sentences return res -def compute_action_match(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_action_match( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) @@ -332,7 +430,13 @@ def compute_action_match(solution_str=None, ground_truth=None, gold_sentences=No return res -def compute_total_action_number(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None): +def compute_total_action_number( + solution_str=None, + ground_truth=None, + gold_sentences=None, + data_source=None, + extra_info=None, +): prediction = solution_str gold = ground_truth prompt, trace = split_trace(prediction) From 9a278844a6b0eb2ffebf7a0568e52c44ce9bc29d Mon Sep 17 00:00:00 2001 From: Zilong Wang Date: Mon, 11 Aug 2025 18:23:35 +0800 Subject: [PATCH 09/10] [fix] pre-commit linter --- examples/rag_agent/rag_agent.py | 12 +++--------- examples/rag_agent/utils.py | 27 ++++++--------------------- 2 files changed, 9 insertions(+), 30 deletions(-) diff --git a/examples/rag_agent/rag_agent.py b/examples/rag_agent/rag_agent.py index bb3243f5..bbca5132 100644 --- a/examples/rag_agent/rag_agent.py +++ b/examples/rag_agent/rag_agent.py @@ -53,9 +53,7 @@ class RAGAgent(LitAgent): def __init__(self): self.mcp_server_url = "http://127.0.0.1:8099/sse" - async def training_rollout_async( - self, task: Any, rollout_id: str, resources: NamedResources - ) -> Any: + async def training_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any: llm: LLM = resources.get("main_llm") print("Training with model:", llm.model, "on endpoint:", llm.endpoint) async with MCPServerSse( @@ -63,9 +61,7 @@ async def training_rollout_async( params={"url": self.mcp_server_url}, ) as server: agent = Agent( - model=LitellmModel( - model="hosted_vllm/" + llm.model, base_url=llm.endpoint - ), + model=LitellmModel(model="hosted_vllm/" + llm.model, base_url=llm.endpoint), model_settings=ModelSettings( max_tokens=4096, temperature=0.7, @@ -84,9 +80,7 @@ async def training_rollout_async( ) return reward - async def validation_rollout_async( - self, task: Any, rollout_id: str, resources: NamedResources - ) -> Any: + async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any: llm: LLM = resources.get("main_llm") resources = { "main_llm": LLM( diff --git a/examples/rag_agent/utils.py b/examples/rag_agent/utils.py index 87e8462b..13379be7 100644 --- a/examples/rag_agent/utils.py +++ b/examples/rag_agent/utils.py @@ -35,15 +35,9 @@ def f1_score(prediction, ground_truth): ZERO_METRIC = (0, 0, 0) - if ( - normalized_prediction in ["yes", "no", "noanswer"] - and normalized_prediction != normalized_ground_truth - ): + if 
normalized_prediction in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: return ZERO_METRIC - if ( - normalized_ground_truth in ["yes", "no", "noanswer"] - and normalized_prediction != normalized_ground_truth - ): + if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: return ZERO_METRIC prediction_tokens = normalized_prediction.split() @@ -64,17 +58,10 @@ def lenient_f1_score(prediction, ground_truth): ZERO_METRIC = (0, 0, 0) - if ( - normalized_ground_truth in ["yes", "no", "noanswer"] - and normalized_prediction != normalized_ground_truth - ): - if normalized_ground_truth == "yes" and ( - "no" in normalized_prediction or "noanswer" in normalized_prediction - ): + if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: + if normalized_ground_truth == "yes" and ("no" in normalized_prediction or "noanswer" in normalized_prediction): return ZERO_METRIC - if normalized_ground_truth == "no" and ( - "yes" in normalized_prediction or "noanswer" in normalized_prediction - ): + if normalized_ground_truth == "no" and ("yes" in normalized_prediction or "noanswer" in normalized_prediction): return ZERO_METRIC prediction_tokens = normalized_prediction.split() @@ -194,9 +181,7 @@ def compute_reward( ): prediction = solution_str gold = ground_truth - return compute_score( - prediction, gold, gold_sentences=gold_sentences, data_source=data_source - ) + return compute_score(prediction, gold, gold_sentences=gold_sentences, data_source=data_source) def compute_em( From 19194e15fc20aa16cae1cb786b3103987e2774e8 Mon Sep 17 00:00:00 2001 From: Zilong Wang Date: Mon, 11 Aug 2025 18:45:48 +0800 Subject: [PATCH 10/10] [fix] change rag example folder name from rag_agent to rag --- examples/{rag_agent => rag}/README.md | 0 examples/{rag_agent => rag}/rag_agent.py | 0 examples/{rag_agent => rag}/train.sh | 0 examples/{rag_agent => rag}/utils.py | 0 .../wiki_retriever_mcp/wiki_retriever_install.sh | 0 .../{rag_agent => rag}/wiki_retriever_mcp/wiki_retriever_mcp.py | 0 6 files changed, 0 insertions(+), 0 deletions(-) rename examples/{rag_agent => rag}/README.md (100%) rename examples/{rag_agent => rag}/rag_agent.py (100%) rename examples/{rag_agent => rag}/train.sh (100%) rename examples/{rag_agent => rag}/utils.py (100%) rename examples/{rag_agent => rag}/wiki_retriever_mcp/wiki_retriever_install.sh (100%) rename examples/{rag_agent => rag}/wiki_retriever_mcp/wiki_retriever_mcp.py (100%) diff --git a/examples/rag_agent/README.md b/examples/rag/README.md similarity index 100% rename from examples/rag_agent/README.md rename to examples/rag/README.md diff --git a/examples/rag_agent/rag_agent.py b/examples/rag/rag_agent.py similarity index 100% rename from examples/rag_agent/rag_agent.py rename to examples/rag/rag_agent.py diff --git a/examples/rag_agent/train.sh b/examples/rag/train.sh similarity index 100% rename from examples/rag_agent/train.sh rename to examples/rag/train.sh diff --git a/examples/rag_agent/utils.py b/examples/rag/utils.py similarity index 100% rename from examples/rag_agent/utils.py rename to examples/rag/utils.py diff --git a/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_install.sh b/examples/rag/wiki_retriever_mcp/wiki_retriever_install.sh similarity index 100% rename from examples/rag_agent/wiki_retriever_mcp/wiki_retriever_install.sh rename to examples/rag/wiki_retriever_mcp/wiki_retriever_install.sh diff --git 
a/examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py b/examples/rag/wiki_retriever_mcp/wiki_retriever_mcp.py similarity index 100% rename from examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py rename to examples/rag/wiki_retriever_mcp/wiki_retriever_mcp.py
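A quick way to sanity-check the retriever server added in this series (step 4 of the README) is to call its `retrieve` tool directly before starting the agent workers. The snippet below is not part of the patches; it is a minimal sketch that assumes the `fastmcp` 2.x client API (`fastmcp.Client` connecting over SSE and `Client.call_tool`) and the default endpoint hard-coded in `wiki_retriever_mcp.py`. Adjust the import and calls if your `fastmcp` version differs.

```python
# smoke_test_retriever.py -- minimal client sketch for the wiki retriever MCP server
import asyncio

from fastmcp import Client  # assumes fastmcp >= 2.x exposes a Client class


async def main() -> None:
    # The /sse suffix matches mcp.run(transport="sse", host="127.0.0.1", port=8099)
    async with Client("http://127.0.0.1:8099/sse") as client:
        # "retrieve" is the tool name registered via @mcp.tool in wiki_retriever_mcp.py;
        # the query string here is an arbitrary example.
        result = await client.call_tool("retrieve", {"query": "Who wrote Hamlet?"})
        print(result)  # expect up to 4 chunks, each with chunk, chunk_id, and distance


if __name__ == "__main__":
    asyncio.run(main())
```

If this prints retrieved chunks, the Faiss index and `nq_list.pkl` loaded correctly; a connection error usually means the server is not running or failed at startup while reading those files.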