# Copyright (c) Microsoft. All rights reserved.
22
3- from typing import Any , Dict , cast
3+ import argparse
4+ import os
5+ from datetime import datetime
6+ from typing import Any , Dict , Optional , cast
47
58from calc_agent import MathProblem , calc_agent
69from datasets import Dataset as HuggingFaceDataset
@@ -47,7 +50,7 @@ def verl_default_config() -> Dict[str, Any]:
4750 "fsdp_config" : {"param_offload" : True },
4851 },
4952 "model" : {
50- "path" : "Qwen/Qwen2.5-0 .5B-Instruct" ,
53+ "path" : "Qwen/Qwen2.5-1 .5B-Instruct" ,
5154 "use_remove_padding" : True ,
5255 "enable_gradient_checkpointing" : True ,
5356 },
@@ -57,43 +60,101 @@ def verl_default_config() -> Dict[str, Any]:
5760 "val_before_train" : True ,
5861 "critic_warmup" : 0 ,
5962 "logger" : ["console" , "wandb" ],
60- "project_name" : "AgentLightningCI " ,
61- "experiment_name" : "train_verl_v0_2 " ,
63+ "project_name" : "AgentLightning " ,
64+ "experiment_name" : "calc_x " ,
6265 "nnodes" : 1 ,
63- "save_freq" : 3 ,
64- "test_freq" : 3 ,
65- "total_epochs" : 1 ,
66- "total_training_steps" : 3 ,
66+ "save_freq" : 64 ,
67+ "test_freq" : 32 ,
68+ "total_epochs" : 2 ,
6769 },
6870 }
6971
7072
def _apply_ci_overrides(config: Dict[str, Any]) -> None:
    """Shrink ``config`` in place for a CI run and publish the run identity.

    The run gets a timestamped experiment name under the ``AgentLightningCI``
    project. When the ``GITHUB_OUTPUT`` environment variable is set, the
    project and run names are appended to the file it points to.

    Args:
        config: The VERL configuration dictionary to modify in place.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    experiment_name = f"calc_x_{timestamp}"
    project_name = "AgentLightningCI"

    # Append to the $GITHUB_OUTPUT file only if the variable is set.
    github_output = os.getenv("GITHUB_OUTPUT")
    if github_output:
        # Explicit encoding so the output does not depend on the runner's locale.
        with open(github_output, "a", encoding="utf-8") as f:
            f.write(f"project_name={project_name}\n")
            f.write(f"run_name={experiment_name}\n")

    # NOTE: nothing is exported to os.environ here; these lines only echo the
    # chosen names to the log.
    print("Set environment variables:")
    print(f"PROJECT_NAME={project_name}")
    print(f"EXPERIMENT_NAME={experiment_name}")

    # Keep the run tiny/light without adding new knobs.
    config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.6
    trainer_config = config["trainer"]
    trainer_config["total_epochs"] = 1
    trainer_config["total_training_steps"] = 6
    trainer_config["test_freq"] = 6
    trainer_config["experiment_name"] = experiment_name
    trainer_config["project_name"] = project_name
    # Drop the save frequency entry entirely for CI runs.
    trainer_config.pop("save_freq", None)


def train(
    *, train_file: str, val_file: str, model: Optional[str], llm_proxy: bool, ci: bool, n_runners: int
) -> None:
    """The training entrypoint function for Calc-X agent with VERL algorithm.

    Args:
        train_file: The path to the training parquet file.
        val_file: The path to the validation parquet file.
        model: The HF model id or path to override the default model.
            A falsy value (``None`` or an empty string) keeps the default.
        llm_proxy: Whether to enable LLM Proxy tracing/adapter.
        ci: Whether to run a minimal CI-style training loop.
        n_runners: The number of runners for the Trainer.
    """
    # Load datasets (respect CLI file paths).
    train_dataset = cast(agl.Dataset[MathProblem], HuggingFaceDataset.from_parquet(train_file).to_list())
    val_dataset = cast(agl.Dataset[MathProblem], HuggingFaceDataset.from_parquet(val_file).to_list())

    print("First 5 rows of train dataset:")
    print(train_dataset[:5])  # type: ignore
    print("First 5 rows of val dataset:")
    print(val_dataset[:5])  # type: ignore

    config = verl_default_config()

    if model:
        config["actor_rollout_ref"]["model"]["path"] = model

    if ci:
        _apply_ci_overrides(config)

    algorithm = agl.VERL(config)

    trainer_kwargs: Dict[str, Any] = {}
    if llm_proxy:
        # The OtelTracer is a deliberate dummy; it is paired with the
        # LLM Proxy trace adapter.
        trainer_kwargs["tracer"] = agl.OtelTracer()
        trainer_kwargs["adapter"] = agl.LlmProxyTraceToTriplet()

    trainer = agl.Trainer(algorithm=algorithm, n_runners=n_runners, **trainer_kwargs)
    trainer.fit(calc_agent, train_dataset, val_dataset=val_dataset)
96136
97137
def main():
    """Parse the command-line options and launch ``train`` with them."""
    cli = argparse.ArgumentParser(description="Train a math calc agent with Agent-lightning + VERL.")

    # Dataset locations.
    cli.add_argument("--train-file", default="data/train.parquet", type=str, help="Path to train parquet file")
    cli.add_argument("--val-file", default="data/test.parquet", type=str, help="Path to val parquet file")

    # Optional model override; None keeps the default.
    cli.add_argument("--model", default=None, type=str, help="HF model id or path (optional)")

    # Boolean switches, off unless given.
    cli.add_argument("--llm-proxy", help="Enable LLM Proxy tracing/adapter", action="store_true")
    cli.add_argument("--ci", help="Run a minimal CI-style training loop", action="store_true")

    cli.add_argument("--n-runners", default=10, type=int, help="Number of runners for Trainer")

    options = cli.parse_args()

    # argparse turns "--train-file" into "train_file" and so on, so the parsed
    # namespace maps one-to-one onto train()'s keyword-only parameters.
    train(**vars(options))


if __name__ == "__main__":
    main()
0 commit comments