diff --git a/examples/rag/README.md b/examples/rag/README.md
new file mode 100644
index 00000000..19119635
--- /dev/null
+++ b/examples/rag/README.md
@@ -0,0 +1,15 @@
+# RAG Agent Example
+
+This example was originally developed to run on a single node with four GPUs, each with at least 40 GB of memory.
+
+1. Prepare the RAG dataset in the `wiki_retriever_mcp` folder. The wiki chunks (`nq_list.pkl`) and the Faiss index (`nq_hnsw_faiss_n32e40.index`) are required. (The full wiki dump files are huge; additional information on obtaining them will be provided later.)
+2. Prepare the training data in the `data` folder. Download it from [here](https://drive.google.com/drive/folders/1hEqOY4EbplUB5ew-8UPFhV_5QU2j7WCN?usp=drive_link). `musique_train.parquet` and `musique_dev_128.parquet` are required.
+3. Set up the environment for the wiki retriever MCP: `bash wiki_retriever_install.sh` (run from the `wiki_retriever_mcp` folder). This installs the required packages into a dedicated conda environment.
+4. Start the wiki retriever MCP server: `python wiki_retriever_mcp.py`.
+5. Start Ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, set the `WANDB_API_KEY` environment variable before starting Ray.
+6. Run the agent: `python rag_agent.py`. This launches 12 agent workers by default.
+7. In another terminal, launch the training server: `bash train.sh`.
+
+## Evaluation
+
+Results are coming soon.
diff --git a/examples/rag/rag_agent.py b/examples/rag/rag_agent.py
new file mode 100644
index 00000000..bbca5132
--- /dev/null
+++ b/examples/rag/rag_agent.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+from typing import Any
+
+from agents import Agent, Runner
+from agents.extensions.models.litellm_model import LitellmModel
+from agents.mcp import MCPServerSse
+from agents.model_settings import ModelSettings
+from utils import compute_scores
+
+from agentlightning import (
+    LLM,
+    LitAgent,
+    NamedResources,
+    Trainer,
+    configure_logger,
+)
+
+configure_logger()
+
+agent_prompt = """You are an assistant who answers questions using a Wikipedia retriever. Answer the question using only the retrieved passages. Verify your answer directly against the text.
+
+After each search:
+- Summarize findings.
+- Decide if info is sufficient.
+  - If sufficient: reply in <answer>...</answer> with your answer. The answer must be extremely concise: a single word or a few words only.
+  - If not: suggest the next search needed to fill info gaps. The system will return the top 4 relevant Wikipedia chunks.
+- Explain your reasoning for the chosen action.
+
+Repeat as needed. When done, wrap your final, concise answer in <answer></answer> tags."""
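+
+# How the answer contract above is consumed downstream (an illustrative
+# sketch; extract_answer is defined in utils.py and recovers the text
+# between the <answer> tags):
+#
+#   >>> from utils import extract_answer
+#   >>> extract_answer("The passage says it opened in 1889. <answer>1889</answer>")
+#   '1889'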
+
+
+class RAGAgent(LitAgent):
+    def __init__(self):
+        super().__init__()
+        self.mcp_server_url = "http://127.0.0.1:8099/sse"
+
+    async def training_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
+        llm: LLM = resources.get("main_llm")
+        print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
+        async with MCPServerSse(
+            name="wiki_retriever_mcp",
+            params={"url": self.mcp_server_url},
+        ) as server:
+            agent = Agent(
+                model=LitellmModel(model="hosted_vllm/" + llm.model, base_url=llm.endpoint),
+                model_settings=ModelSettings(
+                    max_tokens=4096,
+                    temperature=0.7,
+                ),
+                name="Assistant",
+                instructions=agent_prompt,
+                mcp_servers=[server],
+            )
+            result = await Runner.run(agent, task["question"])
+            answer = result.final_output
+            reward = compute_scores(answer, str(task["answer"]))
+            print(
+                "question: {} answer: {} ground_truth: {} reward: {}".format(
+                    task["question"], answer, task["answer"], reward
+                )
+            )
+            return reward
+
+    async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
+        # Reuse the training rollout, but pin the sampling parameters explicitly.
+        llm: LLM = resources.get("main_llm")
+        resources = {
+            "main_llm": LLM(
+                endpoint=llm.endpoint,
+                model=llm.model,
+                sampling_parameters={"temperature": 0.7},
+            )
+        }
+        return await self.training_rollout_async(task, rollout_id, resources)
+
+
+if __name__ == "__main__":
+    Trainer(n_workers=12).fit(RAGAgent(), "http://localhost:9999/")
diff --git a/examples/rag/train.sh b/examples/rag/train.sh
new file mode 100644
index 00000000..7615b085
--- /dev/null
+++ b/examples/rag/train.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+set -e
+
+export N_GPUS=1
+export BASE_MODEL=Qwen/Qwen3-1.7B
+export DATA_DIR=data
+export ROLLOUT_TP_SIZE=1
+export EXPERIMENT_NAME=rag_agent
+export PROJECT_NAME=AgentLightning
+
+echo "Starting training script..."
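+
+# Batch-size bookkeeping (illustrative, assuming verl's usual semantics):
+# each step samples data.train_batch_size=32 prompts and generates
+# actor_rollout_ref.rollout.n=4 rollouts per prompt, i.e. 32 * 4 = 128
+# trajectories. ppo_micro_batch_size_per_gpu=4 bounds per-GPU memory, and
+# gradients are accumulated up to ppo_mini_batch_size=32.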
+
+python -m agentlightning.verl \
+    algorithm.adv_estimator=grpo \
+    data.train_files=${DATA_DIR}/musique_train.parquet \
+    data.val_files=${DATA_DIR}/musique_dev_128.parquet \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
+    trainer.n_gpus_per_node=${N_GPUS} \
+    data.train_batch_size=32 \
+    actor_rollout_ref.rollout.n=4 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.multi_turn.format=hermes \
+    actor_rollout_ref.model.path=${BASE_MODEL} \
+    data.max_prompt_length=4096 \
+    data.max_response_length=2048 \
+    data.truncation='error' \
+    trainer.val_before_train=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.actor.kl_loss_coef=0.000 \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.clip_ratio_low=0.2 \
+    actor_rollout_ref.actor.clip_ratio_high=0.3 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=True \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name=${PROJECT_NAME} \
+    trainer.experiment_name=${EXPERIMENT_NAME} \
+    trainer.nnodes=1 \
+    trainer.save_freq=40 \
+    trainer.test_freq=20 \
+    trainer.total_epochs=2 "$@"
diff --git a/examples/rag/utils.py b/examples/rag/utils.py
new file mode 100644
index 00000000..13379be7
--- /dev/null
+++ b/examples/rag/utils.py
@@ -0,0 +1,442 @@
+import re
+import string
+from collections import Counter
+
+ANS_BEGIN = "<answer>"
+ANS_END = "</answer>"
+GEN_BEGIN = "<|im_start|>assistant\n"
+FORMAT_SCORE = 0.1
+FORMAT_PUNISH = -2
+
+
+def normalize_answer(s):
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+    normalized_prediction = normalize_answer(prediction)
+    normalized_ground_truth = normalize_answer(ground_truth)
+
+    ZERO_METRIC = (0, 0, 0)
+
+    if normalized_prediction in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
+        return ZERO_METRIC
+    if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
+        return ZERO_METRIC
+
+    prediction_tokens = normalized_prediction.split()
+    ground_truth_tokens = normalized_ground_truth.split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return ZERO_METRIC
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1, precision, recall
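+
+# Worked example (illustrative): both strings normalize to "eiffel tower"
+# (lowercased, punctuation and articles removed), so precision = recall = f1 = 1.0:
+#
+#   >>> f1_score("The Eiffel Tower!", "Eiffel Tower")
+#   (1.0, 1.0, 1.0)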
+
+
+def lenient_f1_score(prediction, ground_truth):
+    normalized_prediction = normalize_answer(prediction)
+    normalized_ground_truth = normalize_answer(ground_truth)
+
+    ZERO_METRIC = (0, 0, 0)
+
+    if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
+        if normalized_ground_truth == "yes" and ("no" in normalized_prediction or "noanswer" in normalized_prediction):
+            return ZERO_METRIC
+        if normalized_ground_truth == "no" and ("yes" in normalized_prediction or "noanswer" in normalized_prediction):
+            return ZERO_METRIC
+
+    prediction_tokens = normalized_prediction.split()
+    ground_truth_tokens = normalized_ground_truth.split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return ZERO_METRIC
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1, precision, recall
+
+
+def exact_match_score(prediction, ground_truth):
+    return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def cover_exact_match_score(prediction, ground_truth):
+    return normalize_answer(ground_truth) in normalize_answer(prediction)
+
+
+def extract_answer(response):
+    if ANS_BEGIN not in response or ANS_END not in response:
+        return ""
+    pos1 = response.rfind(ANS_BEGIN)
+    pos2 = response.rfind(ANS_END)
+    assert pos2 != -1
+    if pos1 != -1:
+        ans = response[pos1 + len(ANS_BEGIN) : pos2]
+    else:
+        ans = response[len(ANS_BEGIN) : pos2]
+    return ans
+
+
+def split_response(text):
+    start_response = text.rfind(GEN_BEGIN)
+    response = text[start_response + len(GEN_BEGIN) :]
+    prompt = text[: -len(response)]
+    return prompt, response
+
+
+def extract_recall_chunk(prompt, response):
+    # Regex that matches the text following "1." and "2." within each search step.
+    pattern = r"Retrieved sentences:\s*1\.\s*(.*?)\s*2\.\s*(.*?)(?:\n\s*\d+\.|\n\n|$)"
+
+    # Use re.findall to extract all (s1, s2) pairs.
+    origin_recall = re.findall(pattern, prompt, re.DOTALL)
+    sequential_recall = re.findall(pattern, response, re.DOTALL)
+    origin_recall_set = set(s for pair in origin_recall for s in pair)
+    sequential_recall_set = set(s for pair in sequential_recall for s in pair)
+
+    return origin_recall_set, sequential_recall_set
+
+
+def extract_retrieved_paragraphs(log_text):
+    # Regex that matches the content after "Retrieved paragraph:".
+    pattern = re.compile(r"Retrieved paragraph:\s*(.*?)\n", re.DOTALL)
+
+    # Extract the matched paragraphs and deduplicate.
+    matches = pattern.findall(log_text)
+    matches = list(set(matches))
+    return matches
+
+
+def compute_score(prediction, gold, gold_sentences=None, data_source=None):
+    # format acc
+    format_acc = FORMAT_SCORE
+
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == "":
+        # format score 0.1
+        # if '<answer>' not in response or '</answer>' not in response:
+        #     return 0.0
+        # return 0.0
+        delimiter = "<|im_start|>assistant"
+        last_time_ans = response.split(delimiter)[-1]
+        if "<answer>" not in last_time_ans:
+            return 0.0
+        return format_acc
+
+    # answer acc
+    em, cem = exact_match_score(ans, gold), cover_exact_match_score(ans, gold)
+    f1, prec, recall = f1_score(ans, gold)
+
+    if fact_checking_api(prediction, ans):
+        answer_acc = max(float(em), f1)
+    else:
+        answer_acc = 0
+    # # search acc
+    # if gold_sentences and search_weight:
+    #     origin_recall_set, sequential_recall_set = extract_recall_chunk(prompt, response)
+    #     gold_sentences_set = set(gold_sentences) - origin_recall_set
+    #     matched = gold_sentences_set & sequential_recall_set
+    #     search_acc = len(matched) / len(gold_sentences_set) if len(gold_sentences_set) != 0 else 1.0
+    #     # print(f's_acc {search_acc}|a_acc {answer_acc=}| score {format_acc + (1 - format_acc) * (search_weight + (1 - search_weight) * answer_acc)} |m_len {len(matched)}|g_len {len(gold_sentences_set)}|o_len {len(origin_recall_set)}|s_len {len(sequential_recall_set)}|{gold_sentences_set}|{sequential_recall_set}')
+    #     if search_acc < 1:
+    #         return format_acc + (1 - format_acc) * search_weight * search_acc
+    # # print(f'SCORE: {score} | {ans} | {gold} | {prediction}' )
+
+    return format_acc + (1 - format_acc) * answer_acc
+    # return answer_acc
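+
+
+# Reward shaping at a glance (illustrative arithmetic): a rollout whose final
+# turn has no <answer> tag scores 0.0; a well-formed but wrong answer scores
+# FORMAT_SCORE = 0.1; a well-formed answer with F1 = 1.0 scores
+# 0.1 + (1 - 0.1) * 1.0 = 1.0.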
+
+
+def compute_reward(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    return compute_score(prediction, gold, gold_sentences=gold_sentences, data_source=data_source)
+
+
+def compute_em(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == "":
+        # format score 0.1
+        # if '<answer>' not in response or '</answer>' not in response:
+        #     return 0.0
+        return 0.0
+
+    # answer acc
+    em = exact_match_score(ans, gold)
+    return em
+
+
+def compute_cem(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == "":
+        return 0.0
+
+    # answer acc
+    cem = cover_exact_match_score(ans, gold)
+    return cem
+
+
+def compute_response_cem(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = response
+    if ans == "":
+        return 0.0
+
+    # answer acc
+    cem = cover_exact_match_score(ans, gold)
+    return cem
+
+
+def compute_lenient_f1(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == "":
+        return 0.0
+
+    # answer acc
+    f1, prec, recall = lenient_f1_score(ans, gold)
+    return f1
+
+
+def compute_lenient_response_f1(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = response
+    if ans == "":
+        return 0.0
+
+    # answer acc
+    f1, prec, recall = lenient_f1_score(ans, gold)
+    return f1
+
+
+def fact_checking_api(prediction, ans):
+    return True  # Placeholder for actual fact-checking logic
+
+
+def compute_f1(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == "":
+        return 0.0
+
+    # answer acc
+    f1, prec, recall = f1_score(ans, gold)
+    return f1
+
+
+def compute_format(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, response = split_response(prediction)
+    ans = extract_answer(response)
+    if ans == "":
+        delimiter = "<|im_start|>assistant"
+        last_time_ans = response.split(delimiter)[-1]
+        if "<answer>" not in last_time_ans:
+            return 0
+    return FORMAT_SCORE
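+
+
+# Worked example contrasting the metric variants (hypothetical rollout):
+#   pred = GEN_BEGIN + "It is in <answer>the Paris region</answer>"
+#   compute_em(pred, "Paris")   -> False  (normalized strings differ)
+#   compute_cem(pred, "Paris")  -> True   ("paris" is contained in "paris region")
+#   compute_f1(pred, "Paris")   -> 0.67 approx. (precision 0.5, recall 1.0)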
+
+
+def split_trace(text):
+    start_response = text.find(GEN_BEGIN)
+    response = text[start_response + len(GEN_BEGIN) :]
+    prompt = text[: -len(response)]
+    return prompt, response
+
+
+def compute_action_query(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, trace = split_trace(prediction)
+    res = min(trace.count("") + trace.count(""))
+    return res
+
+
+def compute_action_bm25(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, trace = split_trace(prediction)
+    res = min(trace.count(""))
+    return res
+
+
+def compute_action_read_pre(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, trace = split_trace(prediction)
+    res = min(trace.count(""))
+    return res
+
+
+def compute_action_read_nxt(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, trace = split_trace(prediction)
+    res = min(trace.count(""))
+    return res
+
+
+def compute_action_continue(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, trace = split_trace(prediction)
+    res = min(trace.count(", continue"), trace.count(""))
+    return res
+
+
+def compute_action_match(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, trace = split_trace(prediction)
+    res = min(trace.count(', match_phrase="'), trace.count(""))
+    return res
+
+
+def compute_total_action_number(
+    solution_str=None,
+    ground_truth=None,
+    gold_sentences=None,
+    data_source=None,
+    extra_info=None,
+):
+    prediction = solution_str
+    gold = ground_truth
+    prompt, trace = split_trace(prediction)
+    res = min(trace.count(""))
+    return res
+
+
+# define reward functions for evaluation
+
+
+def compute_scores(answer, ground_truth):
+    parsed_answer = extract_answer(answer)
+    if not parsed_answer:
+        # extract_answer returns "" when the <answer> tags are missing.
+        return -0.1
+    f1, precision, recall = f1_score(parsed_answer, ground_truth)
+    # em = float(exact_match_score(parsed_answer, ground_truth))
+    # cem = float(cover_exact_match_score(answer, ground_truth))
+    return f1
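+
+
+# Usage sketch (illustrative): compute_scores is the reward returned by
+# rag_agent.py, i.e. token-level F1 in [0, 1], or -0.1 when no <answer> tag
+# was produced:
+#
+#   >>> compute_scores("Sufficient info. <answer>Paris</answer>", "Paris")
+#   1.0
+#   >>> compute_scores("I could not find it.", "Paris")
+#   -0.1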
faiss.read_index("/mnt/input/agent_lightning/nq_hnsw_faiss_n32e40.index") +index = faiss.read_index("nq_hnsw_faiss_n32e40.index") +print("Index loaded successfully.") + +model = SentenceTransformer("BAAI/bge-large-en-v1.5") +print("Model loaded successfully.") + +# with open('/mnt/input/agent_lightning/nq_list.pkl', 'rb') as f: +with open("nq_list.pkl", "rb") as f: + chunks = pickle.load(f) +print("Chunks loaded successfully.") + +mcp = FastMCP(name="wiki retrieval mcp") + + +@mcp.tool( + name="retrieve", + description="retrieve relevant chunks from the wikipedia", +) +def retrieve(query: str) -> list: + """ + Retrieve relevant chunks from the Wikipedia dataset. + + Args: + query (str): The query string to search for. + + Returns: + list: A list of dictionaries containing the retrieved chunks and their metadata. + """ + top_k = 4 # Number of top results to return + embedding = model.encode([query], normalize_embeddings=True) + D, I = index.search(embedding, top_k) + + results = [] + for i in range(top_k): + if I[0][i] != -1: + chunk = chunks[I[0][i]] + results.append({"chunk": chunk, "chunk_id": I[0][i], "distance": D[0][i]}) + return results + + +if __name__ == "__main__": + mcp.run(transport="sse", host="127.0.0.1", port=8099)