Skip to content

Commit ce2b050

Browse files
committed
update train_calc_agent script
1 parent 2ff60aa commit ce2b050

File tree

2 files changed

+87
-26
lines changed

2 files changed

+87
-26
lines changed

.github/workflows/examples.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,9 @@ jobs:
164164
set -ex
165165
. .venv/bin/activate
166166
cd examples/calc_x
167-
python calc_agent_dev.py
167+
python legacy_calc_agent_debug.py
168168
env:
169-
OPENAI_API_BASE: http://localhost:12306/
169+
OPENAI_BASE_URL: http://localhost:12306/
170170
OPENAI_API_KEY: dummy
171171

172172
# Calc-X training suddenly works after running the sanity check.
@@ -180,14 +180,14 @@ jobs:
180180
cd examples/calc_x
181181
../../scripts/restart_ray.sh
182182
sleep 5
183-
PYTHONUNBUFFERED=1 python calc_agent.py &
184-
bash train_ci.sh
185-
pkill -f calc_agent.py && echo "SIGTERM sent to calc_agent.py" || echo "No calc_agent.py process found"
186-
while pgrep -f calc_agent.py; do
187-
echo "Waiting for calc_agent.py to finish..."
183+
PYTHONUNBUFFERED=1 python legacy_calc_agent.py &
184+
bash legacy_train.sh
185+
pkill -f legacy_calc_agent.py && echo "SIGTERM sent to legacy_calc_agent.py" || echo "No legacy_calc_agent.py process found"
186+
while pgrep -f legacy_calc_agent.py; do
187+
echo "Waiting for legacy_calc_agent.py to finish..."
188188
sleep 5
189189
done
190-
echo "calc_agent.py has finished."
190+
echo "legacy_calc_agent.py has finished."
191191
sleep 10
192192
shell: bash
193193
env:
@@ -212,7 +212,7 @@ jobs:
212212
cd examples/calc_x
213213
../../scripts/restart_ray.sh
214214
sleep 5
215-
PYTHONUNBUFFERED=1 python calc_agent_v0_2.py
215+
PYTHONUNBUFFERED=1 python train_calc_agent.py --val-file data/test_mini.parquet --ci
216216
sleep 10
217217
shell: bash
218218
env:
@@ -228,7 +228,7 @@ jobs:
228228
cd examples/calc_x
229229
../../scripts/restart_ray.sh
230230
sleep 5
231-
PYTHONUNBUFFERED=1 python calc_agent_v0_2_llm_proxy.py
231+
PYTHONUNBUFFERED=1 python train_calc_agent.py --val-file data/test_mini.parquet --ci --llm-proxy
232232
sleep 10
233233
shell: bash
234234
env:

examples/calc_x/train_calc_agent.py

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

3-
from typing import Any, Dict, cast
3+
import argparse
4+
import os
5+
from datetime import datetime
6+
from typing import Any, Dict, Optional, cast
47

58
from calc_agent import MathProblem, calc_agent
69
from datasets import Dataset as HuggingFaceDataset
@@ -47,7 +50,7 @@ def verl_default_config() -> Dict[str, Any]:
4750
"fsdp_config": {"param_offload": True},
4851
},
4952
"model": {
50-
"path": "Qwen/Qwen2.5-0.5B-Instruct",
53+
"path": "Qwen/Qwen2.5-1.5B-Instruct",
5154
"use_remove_padding": True,
5255
"enable_gradient_checkpointing": True,
5356
},
@@ -57,43 +60,101 @@ def verl_default_config() -> Dict[str, Any]:
5760
"val_before_train": True,
5861
"critic_warmup": 0,
5962
"logger": ["console", "wandb"],
60-
"project_name": "AgentLightningCI",
61-
"experiment_name": "train_verl_v0_2",
63+
"project_name": "AgentLightning",
64+
"experiment_name": "calc_x",
6265
"nnodes": 1,
63-
"save_freq": 3,
64-
"test_freq": 3,
65-
"total_epochs": 1,
66-
"total_training_steps": 3,
66+
"save_freq": 64,
67+
"test_freq": 32,
68+
"total_epochs": 2,
6769
},
6870
}
6971

7072

71-
def train(*, train_file: str, val_file: str, model: str, llm_proxy: bool, ci: bool, n_runners: int):
72-
# TODO: use train_file and val_file from arguments
73-
train_dataset = cast(agl.Dataset[MathProblem], HuggingFaceDataset.from_parquet("data/train.parquet").to_list())
74-
val_dataset = cast(agl.Dataset[MathProblem], HuggingFaceDataset.from_parquet("data/test_mini.parquet").to_list())
73+
def _apply_ci_overrides(config: Dict[str, Any]) -> None:
    """Shrink ``config`` in place to a short CI run and publish the run identity.

    The run is given a timestamped experiment name under the
    ``AgentLightningCI`` project. When ``$GITHUB_OUTPUT`` is set, the project
    and run names are appended to that file as ``project_name`` / ``run_name``
    so that later workflow steps can read them.

    Args:
        config: The VERL configuration dictionary to modify in place.
    """
    project_name = "AgentLightningCI"
    # A timestamp suffix keeps run names from successive CI jobs distinct.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    experiment_name = f"calc_x_{timestamp}"

    # Append step outputs only when running under GitHub Actions (variable set);
    # elsewhere this is skipped. The encoding is explicit so the write does not
    # depend on the runner's locale.
    github_output = os.getenv("GITHUB_OUTPUT")
    if github_output:
        with open(github_output, "a", encoding="utf-8") as f:
            f.write(f"project_name={project_name}\n")
            f.write(f"run_name={experiment_name}\n")

    print("Set environment variables:")
    print(f"PROJECT_NAME={project_name}")
    print(f"EXPERIMENT_NAME={experiment_name}")

    # Keep the run tiny/light without adding new knobs.
    config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.6
    trainer_config = config["trainer"]
    trainer_config["total_epochs"] = 1
    trainer_config["total_training_steps"] = 6
    trainer_config["test_freq"] = 6
    trainer_config["experiment_name"] = experiment_name
    trainer_config["project_name"] = project_name
    # Drop the checkpoint frequency entirely for CI runs.
    trainer_config.pop("save_freq", None)


def train(
    *,
    train_file: str,
    val_file: str,
    model: Optional[str],
    llm_proxy: bool,
    ci: bool,
    n_runners: int,
) -> None:
    """The training entrypoint function for Calc-X agent with VERL algorithm.

    Args:
        train_file: The path to the training parquet file.
        val_file: The path to the validation parquet file.
        model: The HF model id or path to override the default model. A falsy
            value (``None`` or an empty string) keeps the default.
        llm_proxy: Whether to enable LLM Proxy tracing/adapter.
        ci: Whether to run a minimal CI-style training loop (1 epoch, 6
            training steps, validation every 6 steps, no ``save_freq``).
        n_runners: The number of runners for the Trainer.
    """
    # Load datasets from the paths given by the caller.
    train_dataset = cast(agl.Dataset[MathProblem], HuggingFaceDataset.from_parquet(train_file).to_list())
    val_dataset = cast(agl.Dataset[MathProblem], HuggingFaceDataset.from_parquet(val_file).to_list())

    print("First 5 rows of train dataset:")
    print(train_dataset[:5])  # type: ignore
    print("First 5 rows of val dataset:")
    print(val_dataset[:5])  # type: ignore

    config = verl_default_config()

    if model:
        config["actor_rollout_ref"]["model"]["path"] = model

    if ci:
        _apply_ci_overrides(config)

    algorithm = agl.VERL(config)

    if llm_proxy:
        tracer = agl.OtelTracer()  # dummy tracer for LLM Proxy
        adapter = agl.LlmProxyTraceToTriplet()
        trainer = agl.Trainer(algorithm=algorithm, n_runners=n_runners, tracer=tracer, adapter=adapter)
    else:
        trainer = agl.Trainer(algorithm=algorithm, n_runners=n_runners)

    trainer.fit(calc_agent, train_dataset, val_dataset=val_dataset)
96136

97137

138+
def main():
    """Parse the command-line options and hand them to ``train``."""
    parser = argparse.ArgumentParser(description="Train a math calc agent with Agent-lightning + VERL.")

    # Dataset locations.
    parser.add_argument(
        "--train-file",
        type=str,
        default="data/train.parquet",
        help="Path to train parquet file",
    )
    parser.add_argument(
        "--val-file",
        type=str,
        default="data/test.parquet",
        help="Path to val parquet file",
    )

    # Optional model override; None keeps the model from the default config.
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="HF model id or path (optional)",
    )

    # Boolean switches (off unless the flag is present).
    parser.add_argument(
        "--llm-proxy",
        action="store_true",
        help="Enable LLM Proxy tracing/adapter",
    )
    parser.add_argument(
        "--ci",
        action="store_true",
        help="Run a minimal CI-style training loop",
    )

    parser.add_argument(
        "--n-runners",
        type=int,
        default=10,
        help="Number of runners for Trainer",
    )

    options = parser.parse_args()

    # The parsed attribute names (train_file, val_file, model, llm_proxy, ci,
    # n_runners) are exactly train()'s keyword-only parameters.
    train(**vars(options))
157+
158+
98159
if __name__ == "__main__":
99160
main()

0 commit comments

Comments
 (0)