Skip to content

Commit b80ae16

Browse files
vermouth1992 authored and techkang committed
[rollout] feat: add rollout config (volcengine#3010)
### What does this PR do?

- Add rollout config

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 97361ce commit b80ae16

File tree

14 files changed

+211
-38
lines changed

14 files changed

+211
-38
lines changed

tests/utils/test_config_on_cpu.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,19 +17,20 @@
1717

1818
from omegaconf import OmegaConf
1919

20+
from verl.base_config import BaseConfig
2021
from verl.utils import omega_conf_to_dataclass
2122

2223

2324
@dataclass
24-
class TestDataclass:
25-
hidden_size: int
26-
activation: str
25+
class TestDataclass(BaseConfig):
26+
hidden_size: int = 0
27+
activation: str = "relu"
2728

2829

2930
@dataclass
30-
class TestTrainConfig:
31-
batch_size: int
32-
model: TestDataclass
31+
class TestTrainConfig(BaseConfig):
32+
batch_size: int = 0
33+
model: TestDataclass = field(default_factory=TestDataclass)
3334
override_config: dict = field(default_factory=dict)
3435

3536

@@ -79,7 +80,7 @@ def test_command_with_override(self):
7980

8081
# Run the command
8182
result = subprocess.run(
82-
["python3", "scripts/print_cfg.py", "+critic.profiler.extra.any_key=val"],
83+
["python3", "scripts/print_cfg.py"],
8384
capture_output=True,
8485
text=True,
8586
)
@@ -90,7 +91,6 @@ def test_command_with_override(self):
9091
# Verify the output contains expected config information
9192
self.assertIn("critic", result.stdout)
9293
self.assertIn("profiler", result.stdout)
93-
self.assertIn("extra={'any_key': 'val'}", result.stdout)
9494

9595

9696
if __name__ == "__main__":

tests/utils/test_nvtx_profile.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ def test_frozen_config(self):
5656
from verl.utils.profiler.config import ProfilerConfig
5757

5858
# Create a new ProfilerConfig instance
59-
config = ProfilerConfig(all_ranks=False, ranks=[0], extra={"key": "value"})
59+
config = ProfilerConfig(all_ranks=False, ranks=[0])
6060

6161
with self.assertRaises(FrozenInstanceError):
6262
config.all_ranks = True
@@ -70,10 +70,6 @@ def test_frozen_config(self):
7070
with self.assertRaises(TypeError):
7171
config["ranks"] = [1, 2, 3]
7272

73-
assert config["extra"]["key"] == "value"
74-
config["extra"]["key"] = "value2"
75-
assert config["extra"]["key"] == "value2"
76-
7773

7874
class TestNsightSystemsProfiler(unittest.TestCase):
7975
"""Test suite for NsightSystemsProfiler functionality.

verl/base_config.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import collections
16-
from dataclasses import FrozenInstanceError, dataclass, field, fields
16+
from dataclasses import FrozenInstanceError, dataclass, fields
1717
from typing import Any
1818

1919

@@ -27,8 +27,8 @@ class BaseConfig(collections.abc.Mapping):
2727
This allows instances of this class to be used like dictionaries.
2828
"""
2929

30-
_mutable_fields = {"extra"}
31-
extra: dict[str, Any] = field(default_factory=dict)
30+
_mutable_fields = set()
31+
_target_: str = ""
3232

3333
def __setattr__(self, name: str, value):
3434
"""Set the value of an attribute. Check if the attr is mutable before setting the value."""

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,7 @@ actor_rollout_ref:
138138
use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}
139139
load_weight: true
140140
rollout:
141+
_target_: verl.workers.config.RolloutConfig
141142
name: ???
142143
mode: sync
143144
temperature: 1.0
@@ -170,12 +171,14 @@ actor_rollout_ref:
170171
sglang:
171172
attention_backend: null
172173
val_kwargs:
174+
_target_: verl.workers.config.SamplingConfig
173175
top_k: -1
174176
top_p: 1.0
175177
temperature: 0
176178
'n': 1
177179
do_sample: false
178180
multi_turn:
181+
_target_: verl.workers.config.MultiTurnConfig
179182
enable: false
180183
max_assistant_turns: null
181184
tool_config_path: null
@@ -189,13 +192,16 @@ actor_rollout_ref:
189192
format: hermes
190193
calculate_log_probs: false
191194
agent:
195+
_target_: verl.workers.config.AgentLoopConfig
192196
num_workers: 8
193197
agent_loop_config_path: null
194198
custom_async_server:
199+
_target_: verl.workers.config.CustomAsyncServerConfig
195200
path: null
196201
name: null
197202
update_weights_bucket_megabytes: 512
198203
trace:
204+
_target_: verl.workers.config.TraceConfig
199205
backend: null
200206
token2text: false
201207
skip_rollout: false

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,7 @@ actor_rollout_ref:
113113
entropy_from_logits_with_chunking: false
114114
entropy_checkpointing: false
115115
rollout:
116+
_target_: verl.workers.config.RolloutConfig
116117
name: ???
117118
mode: sync
118119
temperature: 1.0
@@ -145,12 +146,14 @@ actor_rollout_ref:
145146
sglang:
146147
attention_backend: null
147148
val_kwargs:
149+
_target_: verl.workers.config.SamplingConfig
148150
top_k: -1
149151
top_p: 1.0
150152
temperature: 0
151153
'n': 1
152154
do_sample: false
153155
multi_turn:
156+
_target_: verl.workers.config.MultiTurnConfig
154157
enable: false
155158
max_assistant_turns: null
156159
tool_config_path: null
@@ -164,13 +167,16 @@ actor_rollout_ref:
164167
format: hermes
165168
calculate_log_probs: false
166169
agent:
170+
_target_: verl.workers.config.AgentLoopConfig
167171
num_workers: 8
168172
agent_loop_config_path: null
169173
custom_async_server:
174+
_target_: verl.workers.config.CustomAsyncServerConfig
170175
path: null
171176
name: null
172177
update_weights_bucket_megabytes: 512
173178
trace:
179+
_target_: verl.workers.config.TraceConfig
174180
backend: null
175181
token2text: false
176182
skip_rollout: false

verl/trainer/config/generation.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ model:
1414
path: ~/models/Qwen2-7B-Instruct
1515
external_lib: null
1616
rollout:
17+
_target_: verl.workers.config.RolloutConfig
1718
name: vllm
1819
mode: sync # sync: LLM, async: AsyncLLM
1920
temperature: 1.0

verl/trainer/config/rollout/rollout.yaml

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
# Target class for this configuration
2+
_target_: verl.workers.config.RolloutConfig
3+
14
# actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future
25
name: ???
36

@@ -103,6 +106,9 @@ engine_kwargs:
103106
# Sampling parameters used during validation.
104107
val_kwargs:
105108

109+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
110+
_target_: verl.workers.config.SamplingConfig
111+
106112
# sampling parameters for validation
107113
# Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.
108114
top_k: -1
@@ -122,6 +128,9 @@ val_kwargs:
122128
# Multi-turn interaction config for tools or chat.
123129
multi_turn:
124130

131+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
132+
_target_: verl.workers.config.MultiTurnConfig
133+
125134
# set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
126135
enable: False
127136

@@ -170,6 +179,9 @@ calculate_log_probs: False
170179
# [Experimental] agent loop based rollout configs
171180
agent:
172181

182+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
183+
_target_: verl.workers.config.AgentLoopConfig
184+
173185
# Number of agent loop workers
174186
num_workers: 8
175187

@@ -188,6 +200,9 @@ agent:
188200
# custom async server configs
189201
custom_async_server:
190202

203+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
204+
_target_: verl.workers.config.CustomAsyncServerConfig
205+
191206
# Path to the custom async server implementation
192207
path: null
193208

@@ -211,6 +226,9 @@ update_weights_bucket_megabytes: 512
211226
# trace rollout data
212227
trace:
213228

229+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
230+
_target_: verl.workers.config.TraceConfig
231+
214232
# trace backend, support mlflow, weave
215233
backend: null
216234

verl/utils/config.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,8 +53,10 @@ def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[
5353
raise ValueError(f"{dataclass_type} must be a dataclass")
5454
cfg = OmegaConf.create(config) # in case it's a dict
5555
# pop _target_ to avoid hydra instantiate error, as most dataclass do not have _target_
56-
if "_target_" in cfg:
57-
cfg.pop("_target_")
56+
# Updated (vermouth1992) We add _target_ to BaseConfig so that it is compatible.
57+
# Otherwise, this code path can't support recursive instantiation.
58+
# if "_target_" in cfg:
59+
# cfg.pop("_target_")
5860
cfg_from_dataclass = OmegaConf.structured(dataclass_type)
5961
# let cfg override the existing vals in `cfg_from_dataclass`
6062
cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg)

verl/workers/config/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from .actor import * # noqa
1717
from .engine import * # noqa
1818
from .optimizer import * # noqa
19-
from . import actor, critic, engine, optimizer
19+
from .rollout import * # noqa
20+
from . import actor, critic, engine, optimizer, rollout
2021

21-
__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__
22+
__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__ + rollout.__all__

verl/workers/config/rollout.py

Lines changed: 141 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass, field
16+
from typing import Optional
17+
18+
from verl.base_config import BaseConfig
19+
from verl.utils.profiler import ProfilerConfig
20+
21+
__all__ = [
22+
"SamplingConfig",
23+
"MultiTurnConfig",
24+
"CustomAsyncServerConfig",
25+
"AgentLoopConfig",
26+
"TraceConfig",
27+
"RolloutConfig",
28+
]
29+
30+
31+
@dataclass
32+
class SamplingConfig(BaseConfig):
33+
temperature: float = 1.0
34+
top_k: int = -1
35+
top_p: float = 1.0
36+
do_sample: bool = True
37+
n: int = 1
38+
39+
40+
@dataclass
41+
class MultiTurnConfig(BaseConfig):
42+
_mutable_fields = {"max_assistant_turns", "max_user_turns"}
43+
44+
enable: bool = False
45+
max_assistant_turns: Optional[int] = None
46+
tool_config_path: Optional[str] = None
47+
max_user_turns: Optional[int] = None
48+
max_parallel_calls: int = 1
49+
max_tool_response_length: int = 256
50+
tool_response_truncate_side: str = "middle"
51+
interaction_config_path: Optional[str] = None
52+
use_inference_chat_template: bool = False
53+
tokenization_sanity_check_mode: str = "strict"
54+
format: str = "hermes"
55+
56+
57+
@dataclass
58+
class CustomAsyncServerConfig(BaseConfig):
59+
path: Optional[str] = None
60+
name: Optional[str] = None
61+
62+
63+
@dataclass
64+
class AgentLoopConfig(BaseConfig):
65+
num_workers: int = 8
66+
agent_loop_config_path: Optional[str] = None
67+
custom_async_server: CustomAsyncServerConfig = field(default_factory=CustomAsyncServerConfig)
68+
69+
70+
@dataclass
71+
class TraceConfig(BaseConfig):
72+
backend: Optional[str] = None
73+
token2text: bool = False
74+
75+
76+
@dataclass
77+
class RolloutConfig(BaseConfig):
78+
_mutable_fields = {"max_model_len"}
79+
80+
name: Optional[str] = None
81+
mode: str = "sync"
82+
83+
temperature: float = 1.0
84+
top_k: int = -1
85+
top_p: float = 1.0
86+
do_sample: bool = True
87+
n: int = 1
88+
89+
prompt_length: int = 512
90+
response_length: int = 512
91+
92+
dtype: str = "bfloat16"
93+
gpu_memory_utilization: float = 0.5
94+
ignore_eos: bool = False
95+
enforce_eager: bool = True
96+
cudagraph_capture_sizes: Optional[list] = None
97+
free_cache_engine: bool = True
98+
tensor_model_parallel_size: int = 2
99+
max_num_batched_tokens: int = 8192
100+
101+
# TODO: enable train_kwargs
102+
# train_sampling_config: SamplingConfig = field(default_factory=SamplingConfig)
103+
104+
val_kwargs: SamplingConfig = field(default_factory=SamplingConfig)
105+
106+
max_model_len: Optional[int] = None
107+
max_num_seqs: int = 1024
108+
109+
# note that the logprob computation should belong to the actor
110+
log_prob_micro_batch_size: Optional[int] = None
111+
log_prob_micro_batch_size_per_gpu: Optional[int] = None
112+
log_prob_use_dynamic_bsz: bool = False
113+
log_prob_max_token_len_per_gpu: int = 16384
114+
115+
disable_log_stats: bool = True
116+
117+
multi_stage_wake_up: bool = False
118+
engine_kwargs: dict = field(default_factory=dict)
119+
120+
calculate_log_probs: bool = False
121+
122+
agent: AgentLoopConfig = field(default_factory=AgentLoopConfig)
123+
124+
trace: TraceConfig = field(default_factory=TraceConfig)
125+
126+
multi_turn: MultiTurnConfig = field(default_factory=MultiTurnConfig)
127+
128+
update_weights_bucket_megabytes: int = 512
129+
130+
skip_rollout: bool = False
131+
132+
skip_dump_dir: str = "/tmp/rollout_dump"
133+
134+
profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
135+
136+
enable_chunked_prefill: bool = True
137+
load_format: str = "dummy_dtensor"
138+
139+
layered_summon: bool = False
140+
141+
layer_name_map: dict = field(default_factory=dict)

0 commit comments

Comments
 (0)