44 changes: 31 additions & 13 deletions recipe/langgraph_agent/chat_model.py
@@ -248,25 +248,43 @@ async def _postprocess(
         content, function_calls = await tool_parser.extract_tool_calls(response_ids)
 
         tool_calls, invalid_tool_calls = [], []
+
         for function_call in function_calls:
             try:
                 args = json.loads(function_call.arguments)
-                if not isinstance(args, dict):
-                    raise json.JSONDecodeError(f"Invalid json tool arguments: {args}")
-                tool_call = ToolCall(
-                    args=args,
-                    name=function_call.name,
-                    id=str(uuid.uuid4()),
-                )
-                tool_calls.append(tool_call)
             except json.JSONDecodeError as e:
-                logger.warning(f"Invalid json tool arguments: {e}")
-                tool_call = InvalidToolCall(
-                    args=function_call.arguments,
+                reason = f"Invalid JSON tool arguments: {e}"
+                logger.warning(reason)
+                invalid_tool_calls.append(
+                    InvalidToolCall(
+                        name=function_call.name,
+                        args=function_call.arguments,
+                        id=str(uuid.uuid4()),
+                        error=reason,
+                    )
+                )
+                continue
+
+            if not isinstance(args, dict):
+                reason = f"Tool arguments must be a JSON object, got {type(args).__name__}"
+                logger.warning(reason)
+                invalid_tool_calls.append(
+                    InvalidToolCall(
+                        name=function_call.name,
+                        args=function_call.arguments,
+                        id=str(uuid.uuid4()),
+                        error=reason,
+                    )
+                )
+                continue
+
+            tool_calls.append(
+                ToolCall(
                     name=function_call.name,
-                    error=f"Invalid json tool arguments: {e}",
+                    args=args,
                     id=str(uuid.uuid4()),
                 )
-                invalid_tool_calls.append(tool_call)
+            )
 
         message = AIMessage(
             content=content,
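For reference, the validation introduced above reduces to two checks: the tool arguments must parse as JSON, and the parsed value must be a JSON object; anything else is recorded as an invalid tool call with a readable reason instead of raising. A minimal standalone sketch of the same rules (a hypothetical helper, not the recipe's actual API):

```python
import json


def classify_tool_call(arguments: str):
    """Return ("valid", args) or ("invalid", reason), mirroring the checks above."""
    try:
        args = json.loads(arguments)
    except json.JSONDecodeError as e:
        return "invalid", f"Invalid JSON tool arguments: {e}"
    if not isinstance(args, dict):
        return "invalid", f"Tool arguments must be a JSON object, got {type(args).__name__}"
    return "valid", args


assert classify_tool_call('{"expr": "3@2"}') == ("valid", {"expr": "3@2"})
assert classify_tool_call("[1, 2]")[0] == "invalid"    # valid JSON, but not an object
assert classify_tool_call("not json")[0] == "invalid"  # does not parse at all
```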
14 changes: 14 additions & 0 deletions recipe/langgraph_agent/example/README.md
@@ -44,6 +44,15 @@ Now, let's prepare two small datasets for training and evaluation:
python recipe/langgraph_agent/example/create_dataset.py
```

- Parameters: `--train_size` (default: 5000), `--test_size` (default: 500), `--output_dir` (default: `data/math_expression_tool`).
- Example with custom sizes/output:
```bash
python recipe/langgraph_agent/example/create_dataset.py \
--train_size 10000 \
--test_size 1000 \
--output_dir data/math_expression_tool
```
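
To sanity-check the generated files, a minimal sketch (it assumes the default `--output_dir`; `pandas` is used because the generator script itself depends on it):
```python
import pandas as pd

train = pd.read_parquet("data/math_expression_tool/train.parquet")
print(train.shape)                   # row count should match --train_size
print(train.columns.tolist())        # should include prompt, reward_model, agent_name
print(train["agent_name"].unique())  # expected: ['math_expression']
```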

Note that the dataset should contain an `agent_name` column set to `math_expression`, which is used by `AgentLoopWorker` to select the agent loop class.
| prompt | reward_model | agent_name |
@@ -65,6 +74,11 @@ Hook all these up and start training:
bash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log
```

To submit the job on a SLURM cluster (the script already contains SBATCH headers):
```bash
sbatch recipe/langgraph_agent/example/run_qwen2.5_3b.sh
```

After a total of 39 steps, the model should achieve 100% accuracy on the test dataset:
- val-aux/lighteval/MATH/reward: 1.0
- val-aux/num_turns/mean: 9.0, the average number of messages, including both assistant and tool turns.
20 changes: 16 additions & 4 deletions recipe/langgraph_agent/example/create_dataset.py
@@ -15,6 +15,8 @@
 Create dataset for calculator
 """
 
+import argparse
+import os
 import random
 
 import pandas as pd
@@ -265,13 +267,23 @@ def generate_data(total_num_dataset, split):


 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Math Expression Dataset Generator")
+    parser.add_argument("--train_size", type=int, default=5000, help="Number of training samples")
+    parser.add_argument("--test_size", type=int, default=500, help="Number of testing samples")
+    parser.add_argument("--output_dir", default="data/math_expression_tool", help="Directory to save the dataset")
+    args = parser.parse_args()
+
     # print(calculate("3@2")) # Output: 5 (3*3 - 2*2)
     # print(calculate("3@2+4")) # Output: 9 (5 + 4)
     # print(calculate("3*(4@2)")) # Output: 24 (3 * 8)
     # print(calculate("(5@3)*2")) # Output: 18 (9 * 2)
 
-    train_dataset = generate_data(total_num_dataset=5000, split="train")
-    test_dataset = generate_data(total_num_dataset=500, split="test")
+    train_dataset = generate_data(total_num_dataset=args.train_size, split="train")
+    test_dataset = generate_data(total_num_dataset=args.test_size, split="test")
 
+    # Make sure the dataset directory exists
+    os.makedirs(args.output_dir, exist_ok=True)
+
-    train_dataset.to_parquet("train.parquet")
-    test_dataset.to_parquet("test.parquet")
+    # Save the datasets to parquet files
+    train_dataset.to_parquet(os.path.join(args.output_dir, "train.parquet"))
+    test_dataset.to_parquet(os.path.join(args.output_dir, "test.parquet"))
94 changes: 70 additions & 24 deletions recipe/langgraph_agent/example/run_qwen2.5_3b.sh
@@ -1,28 +1,60 @@
-set -x
+#!/usr/bin/env bash
+#SBATCH --job-name=rl-langgraph-3B
+#SBATCH --partition=main
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=64
+#SBATCH --gres=gpu:4
+#SBATCH --mem=0
+#SBATCH --time=10:00:00
+#SBATCH --output=%x_%j.out
+#SBATCH --error=%x_%j.err
+
+set -xeuo pipefail
 
+# ================= cluster topology =================
+export GPUS_PER_NODE=${SLURM_GPUS_ON_NODE:-${GPUS_PER_NODE:-1}} # GPUs on this node
+NNODES=${SLURM_JOB_NUM_NODES:-${NNODES:-1}}
+export NNODES
+export RAY_NUM_NODES=$NNODES
+
+# Require at least 2 GPUs
+TOTAL_GPUS=$((GPUS_PER_NODE * NNODES))
+if [ "$TOTAL_GPUS" -lt 2 ]; then
+    echo "Error: at least 2 GPUs are required, detected $TOTAL_GPUS." >&2
+    exit 1
+fi
+
+echo "Using $NNODES nodes and $GPUS_PER_NODE GPUs per node..."
+
+# ================= data/model/tool =================
 HDFS_ROOT=${HDFS_ROOT:-$PWD}
 DATA_ROOT=${DATA_ROOT:-$PWD}
 
-model_path=$DATA_ROOT/model/Qwen2.5-3B-Instruct
+# Prefer local model if present, otherwise fall back to HF hub path
+model_path=${model_path:-$DATA_ROOT/model/Qwen2.5-3B-Instruct}
+if [ ! -d "$model_path" ]; then
+    model_path=Qwen/Qwen2.5-3B-Instruct
+fi
 
-train_files=$DATA_ROOT/dataset/math_expression_tool/train.parquet
-test_files=$DATA_ROOT/dataset/math_expression_tool/test.parquet
+# Use the default output directory produced by create_dataset.py
+train_files=$DATA_ROOT/data/math_expression_tool/train.parquet
+test_files=$DATA_ROOT/data/math_expression_tool/test.parquet
 
-# agent
+# Agent config
 agent_loop_config_path=recipe/langgraph_agent/example/agent.yaml
 
-# wandb
+# =================== wandb ===================
 project_name=math_expression_tool
 experiment_name=qwen2.5-3b
 default_local_dir=$DATA_ROOT/checkpoint/$experiment_name
 
+# ================= algorithm =================
 adv_estimator=grpo
 
-use_kl_in_reward=False
+use_kl_in_reward=false
 kl_coef=0.0
-use_kl_loss=False
+use_kl_loss=false
 kl_loss_coef=0.0
 
 clip_ratio_low=0.2
@@ -38,36 +70,50 @@ ppo_mini_batch_size=16
 n_resp_per_prompt=8
 n_resp_per_prompt_val=1
 
-# ================= perfomance =================
-infer_tp=2 # vllm
-train_sp=4 # train
-offload=True
+# =================== logging ===================
+export RAY_LOGGING_LEVEL=DEBUG
+export HYDRA_FULL_ERROR=1
+
+# ================= performance =================
+export NCCL_IBEXT_DISABLE=1
+export NCCL_NVLS_ENABLE=1
+export NCCL_IB_HCA=mlx5
+export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1
Review comment (Contributor, severity: critical):
The UCX_NET_DEVICES variable is hardcoded to a specific configuration with 8 Mellanox network interfaces. This will cause the script to fail on most systems that do not match this exact hardware setup, which undermines the goal of making the example runnable out-of-the-box. To ensure portability, this line should be removed to allow UCX to use default settings, or the device list should be generated dynamically based on the available hardware.

Reply (Contributor, Author):
This setting is copied from the GSPO recipe (see test_gspo_3b_math.sh:L26) and has worked across multiple hardware configs in my tests. I’m not a UCX expert and open to feedback if there’s a more portable approach.
+export VLLM_USE_V1=1
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+
+infer_tp=2 # vLLM tensor parallel size
+train_sp=4 # Ulysses sequence parallel size for actor
+offload=true
 
 actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))
 log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))
 
 train_files="['$train_files']"
 test_files="['$test_files']"
 
 python3 -m verl.trainer.main_ppo \
     algorithm.adv_estimator=$adv_estimator \
     algorithm.use_kl_in_reward=$use_kl_in_reward \
     algorithm.kl_ctrl.kl_coef=$kl_coef \
     data.train_files="$train_files" \
     data.val_files="$test_files" \
-    data.return_raw_chat=True \
+    data.return_raw_chat=true \
     data.train_batch_size=$train_batch_size \
     data.max_prompt_length=$max_prompt_length \
     data.max_response_length=$max_response_length \
-    data.filter_overlong_prompts=True \
+    data.filter_overlong_prompts=true \
     data.truncation='error' \
-    actor_rollout_ref.model.path=$model_path \
-    actor_rollout_ref.model.use_remove_padding=True \
-    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.model.path="$model_path" \
+    actor_rollout_ref.model.use_remove_padding=true \
+    actor_rollout_ref.model.enable_gradient_checkpointing=true \
     actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
     actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
     actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
     actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
     actor_rollout_ref.actor.clip_ratio_c=10.0 \
     actor_rollout_ref.actor.optim.lr=$actor_lr \
-    actor_rollout_ref.actor.use_dynamic_bsz=True \
+    actor_rollout_ref.actor.use_dynamic_bsz=true \
     actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
     actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
     actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
@@ -86,14 +132,14 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \
     actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
     actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
-    trainer.logger=['console','wandb'] \
+    trainer.logger='["console","wandb"]' \
    trainer.project_name=$project_name \
    trainer.experiment_name=$experiment_name \
-    trainer.n_gpus_per_node=$ARNOLD_WORKER_GPU \
-    trainer.val_before_train=True \
+    trainer.n_gpus_per_node="$GPUS_PER_NODE" \
+    trainer.val_before_train=true \
     trainer.log_val_generations=50 \
-    trainer.nnodes=$ARNOLD_WORKER_NUM \
+    trainer.nnodes="$NNODES" \
     trainer.save_freq=-1 \
-    trainer.default_local_dir=$default_local_dir \
+    trainer.default_local_dir="$default_local_dir" \
     trainer.test_freq=5 \
-    trainer.total_epochs=1 $@
+    trainer.total_epochs=1 "$@"
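
As a side note on the dynamic batching budget in the script above, `actor_max_token_len_per_gpu` is four times the combined prompt and response length, and the log-prob budget is twice that again. A quick worked example with illustrative values (the real `max_prompt_length` and `max_response_length` are set earlier in the script and may differ):

```python
# Illustrative values only; the actual ones are defined earlier in run_qwen2.5_3b.sh.
max_prompt_length = 1024
max_response_length = 8192

actor_max_token_len_per_gpu = (max_prompt_length + max_response_length) * 4  # 36864
log_prob_max_token_len_per_gpu = actor_max_token_len_per_gpu * 2             # 73728
print(actor_max_token_len_per_gpu, log_prob_max_token_len_per_gpu)
```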