diff --git a/examples/rag/README.md b/examples/rag/README.md
new file mode 100644
index 00000000..19119635
--- /dev/null
+++ b/examples/rag/README.md
@@ -0,0 +1,15 @@
+# RAG Agent Example
+
+This example was originally run on a single node with four GPUs, each with at least 40GB of memory.
+
+1. Prepare the RAG dataset in the `wiki_retriever_mcp` folder. The wiki chunks (`nq_list.pkl`) and the Faiss index (`nq_hnsw_faiss_n32e40.index`) are required. (The full wiki dump files are huge; additional information will be provided later.)
+2. Prepare the training data in the `data` folder. Download from [here](https://drive.google.com/drive/folders/1hEqOY4EbplUB5ew-8UPFhV_5QU2j7WCN?usp=drive_link). `musique_train.parquet` and `musique_dev_128.parquet` are required.
+3. Set up the environment for the wiki retriever MCP: from the `wiki_retriever_mcp` folder, run `bash wiki_retriever_install.sh`. This creates the `mcp_server` conda environment and installs the required packages.
+4. Start the wiki retriever MCP server from the same folder (it loads the index and chunk files by relative path): `python wiki_retriever_mcp.py`.
+5. Start Ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, you need to set the WANDB_API_KEY environment variable before starting Ray.
+6. Run the agent: `python rag_agent.py`. This launches 12 agent workers by default.
+7. In another terminal, launch the training server: `bash train.sh`.
+
+## Evaluation
+
+Results are coming soon.
diff --git a/examples/rag/rag_agent.py b/examples/rag/rag_agent.py
new file mode 100644
index 00000000..bbca5132
--- /dev/null
+++ b/examples/rag/rag_agent.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+import os
+import re
+import shutil
+import sys
+import tempfile
+import time
+from typing import Any, Literal, Optional
+
+import dotenv
+import termcolor
+from agents import (
+ Agent,
+ Runner,
+ function_tool,
+ gen_trace_id,
+ set_trace_processors,
+ set_tracing_disabled,
+ trace,
+)
+from agents.extensions.models.litellm_model import LitellmModel
+from agents.mcp import MCPServer, MCPServerSse
+from agents.model_settings import ModelSettings
+from agents.tracing.processors import BatchTraceProcessor, ConsoleSpanExporter
+from utils import compute_scores
+
+import agentlightning
+from agentlightning import (
+ LLM,
+ LitAgent,
+ NamedResources,
+ Trainer,
+ configure_logger,
+ reward,
+)
+
+configure_logger()
+# System prompt for the agent; the final answer must be wrapped in <answer>...</answer> tags.
+agent_prompt = """You are an assistant who answers questions using Wikipedia retriever. Answer the question using only the retrieved passages. Verify your answer directly against the text.
+
+After each search:
+- Summarize findings.
+- Decide if info is sufficient.
+ - If sufficient: reply in <answer>...</answer> with your answer. The answer must be extremely concise: a single word or a few words only.
+ - If not: suggest the next search needed to fill info gaps. The system will return top 3 relevant Wikipedia chunks.
+- Explain your reasoning for the chosen action.
+
+Repeat as needed. When done, wrap your final, concise answer in <answer> tags."""
+
+
+class RAGAgent(LitAgent):
+    """Agent that answers questions with an LLM backed by the wiki retriever MCP server."""
+
+    def __init__(self):
+        super().__init__()  # initialise LitAgent's own state before adding ours
+        self.mcp_server_url = "http://127.0.0.1:8099/sse"
+
+    async def training_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
+        """Answer ``task["question"]`` with the ``main_llm`` resource and return the compute_scores result."""
+        llm: LLM = resources.get("main_llm")
+        print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
+        async with MCPServerSse(
+            name="wiki_retriever_mcp",
+            params={"url": self.mcp_server_url},
+        ) as server:
+            agent = Agent(
+                model=LitellmModel(model="hosted_vllm/" + llm.model, base_url=llm.endpoint),
+                model_settings=ModelSettings(max_tokens=4096, temperature=0.7),
+                name="Assistant",
+                instructions=agent_prompt,
+                mcp_servers=[server],
+            )
+            result = await Runner.run(agent, task["question"])
+            answer = result.final_output
+            # The local is named `score` so it does not shadow the `reward` imported from agentlightning.
+            score = compute_scores(answer, str(task["answer"]))
+            print(f"question:{task['question']} answer: {answer} ground_truth: {task['answer']} reward: {score}")
+            return score
+
+    async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
+        """Rebuild ``main_llm`` with explicit sampling parameters and reuse the training rollout."""
+        llm: LLM = resources.get("main_llm")
+        # NOTE(review): training_rollout_async hard-codes temperature=0.7 and never reads sampling_parameters.
+        resources = {
+            "main_llm": LLM(
+                endpoint=llm.endpoint,
+                model=llm.model,
+                sampling_parameters={"temperature": 0.7},
+            )
+        }
+        return await self.training_rollout_async(task, rollout_id, resources)
+
+
+if __name__ == "__main__":
+ Trainer(n_workers=12).fit(RAGAgent(), "http://localhost:9999/")  # 12 workers; the URL is presumably the training server endpoint — confirm
diff --git a/examples/rag/train.sh b/examples/rag/train.sh
new file mode 100644
index 00000000..7615b085
--- /dev/null
+++ b/examples/rag/train.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+set -e  # abort on the first failing command
+
+export N_GPUS=1  # passed to trainer.n_gpus_per_node below
+export BASE_MODEL=Qwen/Qwen3-1.7B  # passed to actor_rollout_ref.model.path below
+export DATA_DIR=data  # directory holding the train/val parquet files referenced below
+export ROLLOUT_TP_SIZE=1  # passed to actor_rollout_ref.rollout.tensor_model_parallel_size below
+export EXPERIMENT_NAME=rag_agent  # passed to trainer.experiment_name below
+export PROJECT_NAME=AgentLightning  # passed to trainer.project_name below
+
+echo "Starting training script..."
+
+python -m agentlightning.verl \
+    algorithm.adv_estimator=grpo \
+    data.train_files="${DATA_DIR}/musique_train.parquet" \
+    data.val_files="${DATA_DIR}/musique_dev_128.parquet" \
+    actor_rollout_ref.rollout.tensor_model_parallel_size="${ROLLOUT_TP_SIZE}" \
+    trainer.n_gpus_per_node="${N_GPUS}" \
+    data.train_batch_size=32 \
+    actor_rollout_ref.rollout.n=4 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.multi_turn.format=hermes \
+    actor_rollout_ref.model.path="${BASE_MODEL}" \
+    data.max_prompt_length=4096 \
+    data.max_response_length=2048 \
+    data.truncation='error' \
+    trainer.val_before_train=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.actor.kl_loss_coef=0.000 \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.clip_ratio_low=0.2 \
+    actor_rollout_ref.actor.clip_ratio_high=0.3 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=True \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name="${PROJECT_NAME}" \
+    trainer.experiment_name="${EXPERIMENT_NAME}" \
+    trainer.nnodes=1 \
+    trainer.save_freq=40 \
+    trainer.test_freq=20 \
+    trainer.total_epochs=2 "$@"  # "$@" forwards extra overrides one-per-argument, without word splitting
diff --git a/examples/rag/utils.py b/examples/rag/utils.py
new file mode 100644
index 00000000..13379be7
--- /dev/null
+++ b/examples/rag/utils.py
@@ -0,0 +1,442 @@
+import json
+import pickle
+import re
+import string
+import sys
+from collections import Counter
+
+ANS_BEGIN = "<answer>"  # opening tag of the final answer; an empty marker would match everywhere
+ANS_END = "</answer>"  # closing tag of the final answer
+GEN_BEGIN = "<|im_start|>assistant\n"  # marker that starts an assistant turn in the raw text
+FORMAT_SCORE = 0.1  # credit granted for format by the scoring functions below
+FORMAT_PUNISH = -2
+
+
+def normalize_answer(s):
+    """Lower-case *s*, drop punctuation and English articles, and collapse whitespace."""
+    # The steps run in this order: lower -> punctuation -> articles -> whitespace.
+    lowered = s.lower()
+
+    # Delete every ASCII punctuation character in one translate() pass.
+    punctuation_table = str.maketrans("", "", string.punctuation)
+    without_punctuation = lowered.translate(punctuation_table)
+
+    # Replace the standalone words "a", "an" and "the" with a space.
+    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation)
+
+    # split()/join() trims the ends and squeezes whitespace runs to a single space.
+    tokens = without_articles.split()
+    return " ".join(tokens)
+
+
+def f1_score(prediction, ground_truth):
+    """Return (f1, precision, recall) of token overlap between the normalized strings.
+
+    A "yes"/"no"/"noanswer" value on either side must match the other side exactly;
+    otherwise, or when no token is shared, the result is (0, 0, 0).
+    """
+    pred_norm = normalize_answer(prediction)
+    gold_norm = normalize_answer(ground_truth)
+    no_score = (0, 0, 0)
+    special_answers = ("yes", "no", "noanswer")
+
+    # Both guards reduce to: a special answer is involved and the two sides differ.
+    if pred_norm != gold_norm and (pred_norm in special_answers or gold_norm in special_answers):
+        return no_score
+
+    pred_tokens, gold_tokens = pred_norm.split(), gold_norm.split()
+    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
+    if not overlap:
+        return no_score
+    precision, recall = overlap / len(pred_tokens), overlap / len(gold_tokens)
+    return (2 * precision * recall) / (precision + recall), precision, recall
+
+
+def lenient_f1_score(prediction, ground_truth):
+    """Return (f1, precision, recall) of token overlap, zeroing only contradicted yes/no answers.
+
+    For a "yes" gold answer the prediction is rejected when it contains "no" or "noanswer";
+    for a "no" gold answer, when it contains "yes" or "noanswer" (substring checks).
+    """
+    pred_norm = normalize_answer(prediction)
+    gold_norm = normalize_answer(ground_truth)
+    no_score = (0, 0, 0)
+
+    # Substrings whose presence in the prediction contradicts each yes/no gold answer.
+    contradictions = {"yes": ("no", "noanswer"), "no": ("yes", "noanswer")}
+    if pred_norm != gold_norm and any(word in pred_norm for word in contradictions.get(gold_norm, ())):
+        return no_score
+
+    pred_tokens, gold_tokens = pred_norm.split(), gold_norm.split()
+    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
+    if not overlap:
+        return no_score
+    precision, recall = overlap / len(pred_tokens), overlap / len(gold_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1, precision, recall
+
+
+def exact_match_score(prediction, ground_truth):
+ return normalize_answer(prediction) == normalize_answer(ground_truth)  # True only when both normalize to the same string
+
+
+def cover_exact_match_score(prediction, ground_truth):
+ return normalize_answer(ground_truth) in normalize_answer(prediction)  # True when the normalized gold string is a substring of the normalized prediction
+
+
+def extract_answer(response):
+    """Return the text between the last ANS_BEGIN and the last ANS_END in *response*.
+
+    Returns "" when *response* is not a string, when a marker is empty or missing,
+    or when the last ANS_END does not come after the last ANS_BEGIN.
+    """
+    if not isinstance(response, str) or not ANS_BEGIN or not ANS_END:
+        return ""
+    begin, end = response.rfind(ANS_BEGIN), response.rfind(ANS_END)
+    # rfind() yields -1 for a missing marker; end <= begin means there is no closed answer span.
+    return response[begin + len(ANS_BEGIN) : end] if -1 < begin < end else ""
+
+
+def split_response(text):
+    """Split *text* at the last GEN_BEGIN into (prompt incl. marker, response); no marker -> ("", text)."""
+    head, marker, response = text.rpartition(GEN_BEGIN)
+    # rpartition avoids the text[:-0] == "" trap that slicing hits when the response is empty.
+    return head + marker, response
+
+
+def extract_recall_chunk(prompt, response):
+    """Collect the sentences numbered 1. and 2. from every "Retrieved sentences:" section.
+
+    Returns a pair of sets: the sentences found in *prompt* and those found in *response*.
+    """
+    import re
+
+    # Captures the text after "1." and after "2."; item 2 ends at the next numbered
+    # item, a blank line, or the end of the text.
+    item_pattern = r"Retrieved sentences:\s*1\.\s*(.*?)\s*2\.\s*(.*?)(?:\n\s*\d+\.|\n\n|$)"
+    from_prompt = {text for pair in re.findall(item_pattern, prompt, re.DOTALL) for text in pair}
+    from_response = {text for pair in re.findall(item_pattern, response, re.DOTALL) for text in pair}
+    return from_prompt, from_response
+
+
+import re
+
+
+def extract_retrieved_paragraphs(log_text):
+    """Return the distinct texts that follow "Retrieved paragraph:" in *log_text*.
+
+    Each capture runs up to the next newline; the order of the returned list is unspecified.
+    """
+    # Lazy capture up to the first newline after the marker (DOTALL flag kept as in the original).
+    paragraph_re = re.compile(r"Retrieved paragraph:\s*(.*?)\n", re.DOTALL)
+    return list(set(paragraph_re.findall(log_text)))
+
+
+def compute_score(prediction, gold, gold_sentences=None, data_source=None):
+ """Score one full generation as format credit plus (1 - format credit) * answer accuracy."""
+ format_acc = FORMAT_SCORE
+
+ prompt, response = split_response(prediction)
+ ans = extract_answer(response)
+ if ans == "":
+ # No answer could be extracted: choose between no credit and format-only credit.
+ # NOTE(review): the check below tests for the empty string, which every string contains,
+ # so `return 0.0` is unreachable and this branch always returns format_acc.
+ # The marker literal looks like it was lost — TODO restore the intended marker.
+ delimiter = "<|im_start|>assistant"
+ last_time_ans = response.split(delimiter)[-1]  # text after the last assistant-turn marker
+ if "" not in last_time_ans:
+ return 0.0
+ return format_acc
+
+ # Answer accuracy: exact match and token-level F1 against the gold answer.
+ em, cem = exact_match_score(ans, gold), cover_exact_match_score(ans, gold)
+ f1, prec, recall = f1_score(ans, gold)
+
+ if fact_checking_api(prediction, ans):
+ answer_acc = max(float(em), f1)
+ else:
+ answer_acc = 0
+ # # search acc
+ # if gold_sentences and search_weight:
+ # origin_recall_set, sequential_recall_set = extract_recall_chunk(prompt, response)
+ # gold_sentences_set = set(gold_sentences) - origin_recall_set
+ # matched = gold_sentences_set & sequential_recall_set
+ # search_acc = len(matched) / len(gold_sentences_set) if len(gold_sentences_set) != 0 else 1.0
+ # # print(f's_acc {search_acc}|a_acc {answer_acc=}| score {format_acc + (1 - format_acc) * (search_weight + (1 - search_weight) * answer_acc)} |m_len {len(matched)}|g_len {len(gold_sentences_set)}|o_len {len(origin_recall_set)}|s_len {len(sequential_recall_set)}|{gold_sentences_set}|{sequential_recall_set}')
+ # if search_acc < 1:
+ # return format_acc + (1 - format_acc) * search_weight * search_acc
+ # # print(f'SCORE: {score} | {ans} | {gold} | {prediction}' )
+
+ return format_acc + (1 - format_acc) * answer_acc
+ # return answer_acc
+
+
+def compute_reward(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    """Forward to compute_score using the keyword-style metric signature.
+
+    *solution_str* is scored against *ground_truth*; *gold_sentences* and *data_source*
+    are passed through unchanged, and *extra_info* is ignored.
+    """
+    score = compute_score(
+        solution_str, ground_truth, gold_sentences=gold_sentences, data_source=data_source
+    )
+    return score
+
+
+def compute_em(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    """Exact-match metric for the answer extracted from a full generation.
+
+    Args:
+        solution_str: Full generated text; it is split with split_response and the
+            answer is extracted from the response part.
+        ground_truth: Reference answer to compare against.
+        gold_sentences: Unused.
+        data_source: Unused.
+        extra_info: Unused.
+
+    Returns:
+        0.0 when no answer can be extracted, otherwise the boolean result of
+        exact_match_score.
+    """
+    _prompt, reply = split_response(solution_str)
+    extracted = extract_answer(reply)
+    if extracted != "":
+        return exact_match_score(extracted, ground_truth)
+    return 0.0
+
+
+def compute_cem(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    """Cover-exact-match metric for the answer extracted from a full generation.
+
+    Args:
+        solution_str: Full generated text; split with split_response before extraction.
+        ground_truth: Reference answer to look for inside the extracted answer.
+        gold_sentences, data_source, extra_info: Unused.
+
+    Returns:
+        0.0 when no answer can be extracted, otherwise the boolean result of
+        cover_exact_match_score.
+    """
+    _prompt, reply = split_response(solution_str)
+    extracted = extract_answer(reply)
+    if extracted != "":
+        return cover_exact_match_score(extracted, ground_truth)
+    return 0.0
+
+
+def compute_response_cem(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    """Cover-exact-match metric computed on the whole response, without answer extraction.
+
+    Args:
+        solution_str: Full generated text; split with split_response.
+        ground_truth: Reference answer to look for inside the response.
+        gold_sentences, data_source, extra_info: Unused.
+
+    Returns:
+        0.0 for an empty response, otherwise the boolean result of
+        cover_exact_match_score on the response text.
+    """
+    _prompt, reply = split_response(solution_str)
+    # The response text itself plays the role of the answer here.
+    if reply == "":
+        return 0.0
+    return cover_exact_match_score(reply, ground_truth)
+
+
+def compute_lenient_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    """Lenient token-F1 metric for the answer extracted from a full generation.
+
+    Args:
+        solution_str: Full generated text; split with split_response before extraction.
+        ground_truth: Reference answer.
+        gold_sentences, data_source, extra_info: Unused.
+
+    Returns:
+        0.0 when no answer can be extracted, otherwise the first item (f1) of the
+        tuple returned by lenient_f1_score.
+    """
+    _prompt, reply = split_response(solution_str)
+    extracted = extract_answer(reply)
+    if extracted != "":
+        return lenient_f1_score(extracted, ground_truth)[0]
+    return 0.0
+
+
+def compute_lenient_response_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    """Lenient token-F1 metric computed on the whole response, without answer extraction.
+
+    Args:
+        solution_str: Full generated text; split with split_response.
+        ground_truth: Reference answer.
+        gold_sentences, data_source, extra_info: Unused.
+
+    Returns:
+        0.0 for an empty response, otherwise the first item (f1) of the tuple
+        returned by lenient_f1_score on the response text.
+    """
+    _prompt, reply = split_response(solution_str)
+    # The response text itself plays the role of the answer here.
+    if reply == "":
+        return 0.0
+    return lenient_f1_score(reply, ground_truth)[0]
+
+
+def fact_checking_api(prediction, ans):
+ return True # Placeholder: always approves; both arguments are ignored until real fact-checking logic exists
+
+
+def compute_f1(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    """Token-F1 metric for the answer extracted from a full generation.
+
+    Args:
+        solution_str: Full generated text; split with split_response before extraction.
+        ground_truth: Reference answer.
+        gold_sentences, data_source, extra_info: Unused.
+
+    Returns:
+        0.0 when no answer can be extracted, otherwise the first item (f1) of the
+        tuple returned by f1_score.
+    """
+    _prompt, reply = split_response(solution_str)
+    extracted = extract_answer(reply)
+    if extracted != "":
+        return f1_score(extracted, ground_truth)[0]
+    return 0.0
+
+
+def compute_format(  # format-only score for a full generation
+ solution_str=None,
+ ground_truth=None,
+ gold_sentences=None,
+ data_source=None,
+ extra_info=None,
+):
+ prediction = solution_str
+ gold = ground_truth  # unused in this function
+ prompt, response = split_response(prediction)
+ ans = extract_answer(response)
+ if ans == "":
+ delimiter = "<|im_start|>assistant"
+ last_time_ans = response.split(delimiter)[-1]  # text after the last assistant-turn marker
+ if "" not in last_time_ans:  # NOTE(review): "" is contained in every string, so this is never true; the marker literal looks lost — TODO restore
+ return 0
+ return FORMAT_SCORE
+
+
+def split_trace(text):
+    """Split *text* at the first GEN_BEGIN into (prompt incl. marker, trace); no marker -> ("", text)."""
+    head, marker, trace = text.partition(GEN_BEGIN)
+    # partition avoids the text[:-0] == "" trap that slicing hits when the trace is empty.
+    return (head + marker, trace) if marker else ("", text)
+
+
+def compute_action_query(  # presumably meant to count query actions in the trace — confirm
+ solution_str=None,
+ ground_truth=None,
+ gold_sentences=None,
+ data_source=None,
+ extra_info=None,
+):
+ prediction = solution_str
+ gold = ground_truth  # unused in this function
+ prompt, trace = split_trace(prediction)
+ res = min(trace.count("") + trace.count(""))  # NOTE(review): min() of a single int raises TypeError, and count("") is just len(trace) + 1; the marker strings look lost — TODO restore
+ return res
+
+
+def compute_action_bm25(  # presumably meant to count BM25 actions in the trace — confirm
+ solution_str=None,
+ ground_truth=None,
+ gold_sentences=None,
+ data_source=None,
+ extra_info=None,
+):
+ prediction = solution_str
+ gold = ground_truth  # unused in this function
+ prompt, trace = split_trace(prediction)
+ res = min(trace.count(""))  # NOTE(review): min() of a single int raises TypeError, and count("") is just len(trace) + 1; the marker string looks lost — TODO restore
+ return res
+
+
+def compute_action_read_pre(  # presumably meant to count "read previous" actions in the trace — confirm
+ solution_str=None,
+ ground_truth=None,
+ gold_sentences=None,
+ data_source=None,
+ extra_info=None,
+):
+ prediction = solution_str
+ gold = ground_truth  # unused in this function
+ prompt, trace = split_trace(prediction)
+ res = min(trace.count(""))  # NOTE(review): min() of a single int raises TypeError, and count("") is just len(trace) + 1; the marker string looks lost — TODO restore
+ return res
+
+
+def compute_action_read_nxt(  # presumably meant to count "read next" actions in the trace — confirm
+ solution_str=None,
+ ground_truth=None,
+ gold_sentences=None,
+ data_source=None,
+ extra_info=None,
+):
+ prediction = solution_str
+ gold = ground_truth  # unused in this function
+ prompt, trace = split_trace(prediction)
+ res = min(trace.count(""))  # NOTE(review): min() of a single int raises TypeError, and count("") is just len(trace) + 1; the marker string looks lost — TODO restore
+ return res
+
+
+def compute_action_continue(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    """Count ``", continue"`` occurrences in the trace part of *solution_str*.
+
+    The trace is the second item returned by split_trace. The count goes through min()
+    together with trace.count(""), which is len(trace) + 1 and therefore never smaller.
+
+    *ground_truth*, *gold_sentences*, *data_source* and *extra_info* are unused.
+    """
+    _prompt, trace = split_trace(solution_str)
+    keyword_hits = trace.count(", continue")
+    upper_bound = trace.count("")
+    return min(keyword_hits, upper_bound)
+
+
+def compute_action_match(solution_str=None, ground_truth=None, gold_sentences=None, data_source=None, extra_info=None):
+    """Count ``', match_phrase="'`` occurrences in the trace part of *solution_str*.
+
+    The trace is the second item returned by split_trace. The count goes through min()
+    together with trace.count(""), which is len(trace) + 1 and therefore never smaller.
+
+    *ground_truth*, *gold_sentences*, *data_source* and *extra_info* are unused.
+    """
+    _prompt, trace = split_trace(solution_str)
+    keyword_hits = trace.count(', match_phrase="')
+    upper_bound = trace.count("")
+    return min(keyword_hits, upper_bound)
+
+
+def compute_total_action_number(  # presumably meant to count all actions in the trace — confirm
+ solution_str=None,
+ ground_truth=None,
+ gold_sentences=None,
+ data_source=None,
+ extra_info=None,
+):
+ prediction = solution_str
+ gold = ground_truth  # unused in this function
+ prompt, trace = split_trace(prediction)
+ res = min(trace.count(""))  # NOTE(review): min() of a single int raises TypeError, and count("") is just len(trace) + 1; the marker string looks lost — TODO restore
+ return res
+
+
+# define reward functions for evaluation
+
+
+def compute_scores(answer, ground_truth):
+    """Score a raw agent reply: -0.1 when no answer can be extracted, else token F1 vs *ground_truth*."""
+    parsed_answer = extract_answer(answer)
+    # extract_answer signals "no answer" with "", so test falsiness rather than `is None`.
+    if not parsed_answer:
+        return -0.1
+    f1, _precision, _recall = f1_score(parsed_answer, ground_truth)
+    return f1
diff --git a/examples/rag/wiki_retriever_mcp/wiki_retriever_install.sh b/examples/rag/wiki_retriever_mcp/wiki_retriever_install.sh
new file mode 100644
index 00000000..7b22cdfc
--- /dev/null
+++ b/examples/rag/wiki_retriever_mcp/wiki_retriever_install.sh
@@ -0,0 +1,3 @@
+conda create -n mcp_server python=3.12 -y
+source "$(conda info --base)/etc/profile.d/conda.sh" && conda activate mcp_server  # load conda's shell hook first: `conda activate` is unavailable in a non-interactive script otherwise
+pip install faiss-cpu==1.11.0 fastmcp==2.5.1 sentence-transformers==4.1.0
\ No newline at end of file
diff --git a/examples/rag/wiki_retriever_mcp/wiki_retriever_mcp.py b/examples/rag/wiki_retriever_mcp/wiki_retriever_mcp.py
new file mode 100644
index 00000000..5affb65c
--- /dev/null
+++ b/examples/rag/wiki_retriever_mcp/wiki_retriever_mcp.py
@@ -0,0 +1,48 @@
+import faiss
+from sentence_transformers import SentenceTransformer
+import pickle
+from fastmcp import FastMCP
+
+# Load the FAISS index; the path is relative, so it is resolved against the working directory.
+index = faiss.read_index("nq_hnsw_faiss_n32e40.index")
+print("Index loaded successfully.")
+
+model = SentenceTransformer("BAAI/bge-large-en-v1.5")  # used by retrieve() to embed queries
+print("Model loaded successfully.")
+
+# NOTE(review): pickle.load can execute arbitrary code from the file; only load a trusted nq_list.pkl.
+with open("nq_list.pkl", "rb") as f:
+ chunks = pickle.load(f)
+print("Chunks loaded successfully.")
+
+mcp = FastMCP(name="wiki retrieval mcp")
+
+
+@mcp.tool(
+    name="retrieve",
+    description="retrieve relevant chunks from the wikipedia",
+)
+def retrieve(query: str) -> list:
+    """
+    Retrieve relevant chunks from the Wikipedia dataset.
+
+    Args:
+        query (str): The query string to search for.
+
+    Returns:
+        list: One dict per hit with the keys "chunk", "chunk_id" (int) and "distance" (float).
+    """
+    top_k = 4  # Number of top results to return
+    embedding = model.encode([query], normalize_embeddings=True)
+    distances, indices = index.search(embedding, top_k)
+    # Ids of -1 mark slots without a hit and are skipped. NumPy scalars are cast to
+    # built-in int/float so the payload serialises as plain JSON numbers.
+    return [
+        {"chunk": chunks[int(idx)], "chunk_id": int(idx), "distance": float(dist)}
+        for idx, dist in zip(indices[0], distances[0])
+        if idx != -1
+    ]
+
+
+if __name__ == "__main__":
+ mcp.run(transport="sse", host="127.0.0.1", port=8099)  # same host/port as the SSE URL used by rag_agent.py (http://127.0.0.1:8099/sse)