15 changes: 15 additions & 0 deletions examples/rag/README.md
@@ -0,0 +1,15 @@
# RAG Agent Example

This example is designed to run on a single node with four GPUs, each with at least 40 GB of memory.

1. Prepare the RAG dataset in the `wiki_retriever_mcp` folder. The Wiki chunks (`nq_list.pkl`) and the Faiss index (`nq_hnsw_faiss_n32e40.index`) are required. (The full wiki dump files are large; additional information will be provided later.)
2. Prepare the training data in the `data` folder. Download it from [here](https://drive.google.com/drive/folders/1hEqOY4EbplUB5ew-8UPFhV_5QU2j7WCN?usp=drive_link). `musique_train.parquet` and `musique_dev_128.parquet` are required. (A quick way to verify these files is sketched after this list.)
3. Set up the environment for the wiki retriever MCP: `bash wiki_retriever_install.sh`. This installs the required packages.
4. Start the wiki retriever MCP server: `python wiki_retriever_mcp.py`.
5. Start Ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, set the `WANDB_API_KEY` environment variable before starting Ray.
6. Run the agent: `python rag_agent.py`. By default, this launches 12 agent workers.
7. In another terminal, launch the training server: `bash train.sh`.
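
To confirm that the files from steps 1–2 are in place and readable before launching the servers, a minimal check along the following lines can help. This is only a convenience sketch, assuming the default file layout above, that the script is run from `examples/rag`, and that `faiss` and `pandas` (with parquet support) are available in the retriever environment.

```python
# Optional sanity check -- not part of the example itself.
# Verifies that the retriever artifacts and MuSiQue parquet files load cleanly.
import pickle

import faiss        # assumed to be installed by wiki_retriever_install.sh
import pandas as pd

# Wiki chunks and HNSW index expected by wiki_retriever_mcp.py (step 1).
with open("wiki_retriever_mcp/nq_list.pkl", "rb") as f:
    chunks = pickle.load(f)
index = faiss.read_index("wiki_retriever_mcp/nq_hnsw_faiss_n32e40.index")
print(f"{len(chunks)} wiki chunks, {index.ntotal} indexed vectors")

# Training and validation data expected by train.sh (step 2).
for name in ["musique_train.parquet", "musique_dev_128.parquet"]:
    df = pd.read_parquet(f"data/{name}")
    print(f"{name}: {len(df)} rows, columns = {list(df.columns)}")
```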

## Evaluation

Results are coming soon.
96 changes: 96 additions & 0 deletions examples/rag/rag_agent.py
@@ -0,0 +1,96 @@
from __future__ import annotations

from typing import Any

from agents import Agent, Runner
from agents.extensions.models.litellm_model import LitellmModel
from agents.mcp import MCPServerSse
from agents.model_settings import ModelSettings
from utils import compute_scores

from agentlightning import (
    LLM,
    LitAgent,
    NamedResources,
    Trainer,
    configure_logger,
)

configure_logger()

agent_prompt = """You are an assistant who answers questions using Wikipedia retriever. Answer the question using only the retrieved passages. Verify your answer directly against the text.

After each search:
- Summarize findings.
- Decide if info is sufficient.
- If sufficient: reply in <answer>...</answer> with your answer. The answer must be extremely concise: a single word or a few words only.
- If not: suggest the next search needed to fill info gaps. The system will return top 3 relevant Wikipedia chunks.
- Explain your reasoning for the chosen action.

Repeat as needed. When done, wrap your final, concise answer in <answer> tags."""
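
# Illustration only (assumption, not used by the agent): one way the final
# <answer>...</answer> span could be pulled out of the model's reply. The actual
# extraction and scoring are assumed to live in `compute_scores` from utils.py,
# which is not shown here.
def extract_answer(text: str) -> str | None:
    import re  # local import keeps this sketch self-contained

    match = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return match.group(1).strip() if match else None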


class RAGAgent(LitAgent):
    def __init__(self):
        super().__init__()
        # SSE endpoint of the wiki retriever MCP server (wiki_retriever_mcp.py).
        self.mcp_server_url = "http://127.0.0.1:8099/sse"

    async def training_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
        llm: LLM = resources.get("main_llm")
        print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
        async with MCPServerSse(
            name="wiki_retriever_mcp",
            params={"url": self.mcp_server_url},
        ) as server:
            agent = Agent(
                model=LitellmModel(model="hosted_vllm/" + llm.model, base_url=llm.endpoint),
                model_settings=ModelSettings(
                    max_tokens=4096,
                    temperature=0.7,
                ),
                name="Assistant",
                instructions=agent_prompt,
                mcp_servers=[server],
            )
            result = await Runner.run(agent, task["question"])
            answer = result.final_output
            reward = compute_scores(answer, str(task["answer"]))
            print(
                f"question: {task['question']} answer: {answer} "
                f"ground_truth: {task['answer']} reward: {reward}"
            )
            return reward

    async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
        llm: LLM = resources.get("main_llm")
        # Reuse the training rollout for validation, overriding only the sampling parameters.
        resources = {
            "main_llm": LLM(
                endpoint=llm.endpoint,
                model=llm.model,
                sampling_parameters={"temperature": 0.7},
            )
        }
        return await self.training_rollout_async(task, rollout_id, resources)


if __name__ == "__main__":
    # The Trainer connects to the training server launched separately via `bash train.sh`.
    Trainer(n_workers=12).fit(RAGAgent(), "http://localhost:9999/")
53 changes: 53 additions & 0 deletions examples/rag/train.sh
@@ -0,0 +1,53 @@
#!/bin/bash

set -e

export N_GPUS=1
export BASE_MODEL=Qwen/Qwen3-1.7B
export DATA_DIR=data
export ROLLOUT_TP_SIZE=1
export EXPERIMENT_NAME=rag_agent
export PROJECT_NAME=AgentLightning

echo "Starting training script..."

python -m agentlightning.verl \
algorithm.adv_estimator=grpo \
data.train_files=${DATA_DIR}/musique_train.parquet \
data.val_files=${DATA_DIR}/musique_dev_128.parquet \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
trainer.n_gpus_per_node=${N_GPUS} \
data.train_batch_size=32 \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.model.path=${BASE_MODEL} \
data.max_prompt_length=4096 \
data.max_response_length=2048 \
data.truncation='error' \
trainer.val_before_train=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.000 \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.clip_ratio_low=0.2 \
actor_rollout_ref.actor.clip_ratio_high=0.3 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name=${PROJECT_NAME} \
trainer.experiment_name=${EXPERIMENT_NAME} \
trainer.nnodes=1 \
trainer.save_freq=40 \
trainer.test_freq=20 \
    trainer.total_epochs=2 "$@"
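
As a rough sanity check on the rollout volume these settings imply, the arithmetic below assumes verl's usual meaning of the knobs (`train_batch_size` counts prompts per step, `rollout.n` is the GRPO group size, and the micro-batch sizes count samples per GPU); the exact semantics can differ between verl versions, so treat this as an estimate rather than a guarantee.

```python
# Back-of-the-envelope rollout arithmetic for the settings in train.sh.
train_batch_size = 32    # data.train_batch_size (prompts per step, assumed)
rollout_n = 4            # actor_rollout_ref.rollout.n (samples per prompt)
n_gpus = 1               # trainer.n_gpus_per_node
micro_bsz_per_gpu = 4    # actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu

trajectories_per_step = train_batch_size * rollout_n                           # 128
micro_batches_per_gpu = trajectories_per_step // (n_gpus * micro_bsz_per_gpu)  # 32

print(f"{trajectories_per_step} trajectories per optimization step")
print(f"~{micro_batches_per_gpu} gradient micro-batches per GPU per step")
```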