Skip to content

Commit 90db3cc

Browse files
committed
try not to discriminate between 0.5 and 0.6
1 parent f7c90f3 commit 90db3cc

File tree

3 files changed

+25
-11
lines changed

3 files changed

+25
-11
lines changed

agentlightning/verl/entrypoint.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ def run_ppo(
4141
) -> None:
4242
if not ray.is_initialized():
4343
# this is for local ray cluster
44-
installed_verl = version("verl")
45-
if packaging_version.parse(installed_verl) >= packaging_version.parse("0.6.0"):
44+
try:
45+
# verl >= 0.6.0
4646
num_cpus = config.ray_kwargs.ray_init.num_cpus
47-
else:
47+
except AttributeError:
48+
# verl < 0.6.0
4849
num_cpus = config.ray_init.num_cpus
4950
ray.init(
5051
runtime_env={

examples/calc_x/train_calc_agent.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def verl_default_config() -> Dict[str, Any]:
6060
"multi_turn": {"format": "hermes"},
6161
"name": "vllm",
6262
"gpu_memory_utilization": 0.6,
63+
"engine_kwargs": {
64+
"vllm": {
65+
"enable_auto_tool_choice": True,
66+
"tool_call_parser": "hermes",
67+
}
68+
},
6369
},
6470
"actor": {
6571
"ppo_mini_batch_size": 32,
@@ -98,14 +104,14 @@ def verl_default_config() -> Dict[str, Any]:
98104
"total_epochs": 2,
99105
},
100106
}
101-
installed_verl = version("verl")
102-
if packaging_version.parse(installed_verl) >= packaging_version.parse("0.6.0"):
103-
config["actor_rollout_ref"]["rollout"]["engine_kwargs"] = {
104-
"vllm": {
105-
"enable_auto_tool_choice": True,
106-
"tool_call_parser": "hermes",
107-
},
108-
}
107+
# installed_verl = version("verl")
108+
# if packaging_version.parse(installed_verl) >= packaging_version.parse("0.6.0"):
109+
# config["actor_rollout_ref"]["rollout"]["engine_kwargs"] = {
110+
# "vllm": {
111+
# "enable_auto_tool_choice": True,
112+
# "tool_call_parser": "hermes",
113+
# },
114+
# }
109115
return config
110116

111117

examples/spider/train_sql_agent.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@
5252
"multi_turn": {"format": "hermes"},
5353
"name": "vllm",
5454
"gpu_memory_utilization": 0.8,
55+
"engine_kwargs": {
56+
"vllm": {
57+
"enable_auto_tool_choice": True,
58+
"tool_call_parser": "hermes",
59+
}
60+
},
5561
},
5662
"actor": {
5763
"ppo_mini_batch_size": 32,
@@ -139,6 +145,7 @@ def config_train_llama() -> Dict[str, Any]:
139145

140146
config = deepcopy(RL_TRAINING_CONFIG)
141147
config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json"
148+
config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json"
142149
config["actor_rollout_ref"]["model"]["path"] = "meta-llama/Llama-3.2-1B-Instruct"
143150
return config
144151

0 commit comments

Comments
 (0)