15 changes: 15 additions & 0 deletions examples/rag_agent/README.md
@@ -0,0 +1,15 @@
# RAG Agent Example

This example is designed to run on a single node with four GPUs, each with at least 40 GB of memory.

1. Prepare the RAG dataset in the `wiki_retriever_mcp` folder. The wiki chunks (`nq_list.pkl`) and the Faiss index (`nq_hnsw_faiss_n32e40.index`) are required. (The full wiki dump files are huge; additional information on obtaining them will be provided later.)
2. Prepare the training data in the `data` folder. Download it from [here](https://drive.google.com/drive/folders/1hEqOY4EbplUB5ew-8UPFhV_5QU2j7WCN?usp=drive_link); `musique_train.parquet` and `musique_dev_128.parquet` are required. A quick way to inspect them is shown after this list.
3. Set up the environment for the wiki retriever MCP: `bash wiki_retriever_install.sh`. This installs the packages the retriever needs.
4. Start the wiki retriever MCP server: `python wiki_retriever_mcp.py`.
5. Start Ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, set the `WANDB_API_KEY` environment variable before starting Ray.
6. Run the agent: `python rag_agent.py`. This launches 12 agent workers by default.
7. In another terminal, launch the training server: `bash train.sh`.
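
The rollout code in `rag_agent.py` reads `task["question"]` and `task["answer"]` from each example, so both columns are expected in the parquet files. A minimal sketch to sanity-check the download, assuming `pandas` with parquet support is installed:

```python
import pandas as pd

# The rollout code accesses task["question"] and task["answer"],
# so both columns should be present.
df = pd.read_parquet("data/musique_train.parquet")
print(len(df), "rows; columns:", list(df.columns))
print(df.head())
```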

## Evaluation

Results are coming soon.
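
In the meantime, the reward that training optimizes can be exercised directly. A minimal sketch, run from inside `examples/rag_agent`, assuming `utils.compute_scores(prediction, ground_truth)` returns a scalar reward as it is called in `rag_agent.py`:

```python
from utils import compute_scores

# Same function rag_agent.py uses as the rollout reward; how the
# <answer> tags are parsed is up to the implementation in utils.py.
prediction = "<answer>Paris</answer>"  # hypothetical model output
ground_truth = "Paris"                 # hypothetical reference answer
print(compute_scores(prediction, ground_truth))
```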
77 changes: 77 additions & 0 deletions examples/rag_agent/rag_agent.py
@@ -0,0 +1,77 @@
from __future__ import annotations

import re
from typing import Any

from agents import Agent, Runner
from agents.extensions.models.litellm_model import LitellmModel
from agents.mcp import MCPServerSse
from agents.model_settings import ModelSettings
from utils import compute_scores

from agentlightning import LLM, LitAgent, NamedResources, Trainer, configure_logger

configure_logger()

agent_prompt = """You are an assistant who answers questions using a Wikipedia retriever. Answer the question using only the retrieved passages, and verify your answer directly against the text.

After each search:
- Summarize your findings.
- Decide whether the information is sufficient.
- If sufficient: reply with your answer wrapped in <answer>...</answer>. The answer must be extremely concise: a single word or a few words only.
- If not: suggest the next search needed to fill the information gaps. The system will return the top 3 relevant Wikipedia chunks.
- Explain your reasoning for the chosen action.

Repeat as needed. When done, wrap your final, concise answer in <answer> tags."""
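

# Illustrative sketch (not called by the rollout below): the prompt asks the
# model to wrap its final answer in <answer>...</answer> tags. Extracting it
# would look roughly like this; the actual scoring lives in utils.compute_scores.
def extract_answer(text: str) -> str:
    match = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return match.group(1).strip() if match else text.strip()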


class RAGAgent(LitAgent):
    def __init__(self):
        super().__init__()
        # Must match the address served by wiki_retriever_mcp.py.
        self.mcp_server_url = "http://127.0.0.1:8099/sse"

async def training_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
llm: LLM = resources.get("main_llm")
print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
async with MCPServerSse(
name="wiki_retriever_mcp",
params={"url": self.mcp_server_url},
) as server:
agent = Agent(
model=LitellmModel(model="hosted_vllm/" + llm.model, base_url=llm.endpoint),
model_settings=ModelSettings(
max_tokens=4096,
temperature=0.7,
),
name="Assistant",
instructions=agent_prompt,
mcp_servers=[server],
)
result = await Runner.run(agent, task["question"])
answer = result.final_output
reward = compute_scores(answer, str(task["answer"]))
print("question:{} answer: {} ground_truth: {} reward: {}".format(task["question"], answer, task["answer"], reward))
return reward

    async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
        # Validation reuses the training rollout; only the sampling parameters are pinned here.
        llm: LLM = resources.get("main_llm")
resources = {
"main_llm": LLM(
endpoint=llm.endpoint,
model=llm.model,
sampling_parameters={"temperature": 0.7},
)
}
return await self.training_rollout_async(task, rollout_id, resources)


if __name__ == "__main__":
    # Connect 12 parallel agent workers to the Agent Lightning server
    # started by the training side (see train.sh, step 7 in the README).
    Trainer(n_workers=12).fit(RAGAgent(), "http://localhost:9999/")
53 changes: 53 additions & 0 deletions examples/rag_agent/train.sh
@@ -0,0 +1,53 @@
#!/bin/bash

set -e

export N_GPUS=1
export BASE_MODEL=Qwen/Qwen3-1.7B
export DATA_DIR=data
export ROLLOUT_TP_SIZE=1
export EXPERIMENT_NAME=rag_agent
export PROJECT_NAME=AgentLightning
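
# Any extra Hydra-style overrides passed to this script are forwarded to the
# trainer via "$@" on the last line, e.g.: bash train.sh trainer.total_epochs=4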

echo "Starting training script..."

python -m agentlightning.verl \
algorithm.adv_estimator=grpo \
data.train_files=${DATA_DIR}/musique_train.parquet \
data.val_files=${DATA_DIR}/musique_dev_128.parquet \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
trainer.n_gpus_per_node=${N_GPUS} \
data.train_batch_size=32 \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.model.path=${BASE_MODEL} \
data.max_prompt_length=4096 \
data.max_response_length=2048 \
data.truncation='error' \
trainer.val_before_train=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.000 \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.clip_ratio_low=0.2 \
actor_rollout_ref.actor.clip_ratio_high=0.3 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name=${PROJECT_NAME} \
trainer.experiment_name=${EXPERIMENT_NAME} \
trainer.nnodes=1 \
trainer.save_freq=40 \
trainer.test_freq=20 \
    trainer.total_epochs=2 "$@"