
Commit a4a20b3

vermouth1992 and gemini-code-assist[bot] authored and committed
[BREAKING] [rollout] feat: add a separate rollout worker (volcengine#3071)
### What does this PR do?

- Introduce a separate rollout worker that can be instantiated without the hybrid engine
- Introduce a ModelConfig that wraps all model-related config
- Remove hf_rollout (will replace with TP support in the future if needed)
- Next PR: modify MegatronWorker to use the separate rollout worker

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent c01d56c commit a4a20b3

File tree

10 files changed: +439 −82 lines

.github/workflows/e2e_ppo_trainer.yml

Lines changed: 9 additions & 9 deletions

@@ -101,15 +101,15 @@ jobs:
           ray stop --force
           python3 examples/data_preprocess/gsm8k.py
       # HF sanity
-      - name: Running GSM8K E2E training tests on 1 L20 GPU with hf for sanity
-        run: |
-          ray stop --force
-          bash tests/special_e2e/ppo_trainer/run_single_gpu.sh
-      # HF sanity
-      - name: Running GSM8K E2E training tests on 1 L20 GPU with engine interface for sanity.
-        run: |
-          ray stop --force
-          bash tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh
+      # - name: Running GSM8K E2E training tests on 1 L20 GPU with hf for sanity
+      #   run: |
+      #     ray stop --force
+      #     bash tests/special_e2e/ppo_trainer/run_single_gpu.sh
+      # # HF sanity
+      # - name: Running GSM8K E2E training tests on 1 L20 GPU with engine interface for sanity.
+      #   run: |
+      #     ray stop --force
+      #     bash tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh
       # Function RM
       - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (FSDP_SIZE=8)
         run: |

docs/index.rst

Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ verl is fast with:
    amd_tutorial/amd_build_dockerfile_page.rst
    amd_tutorial/amd_vllm_page.rst
    ascend_tutorial/ascend_quick_start.rst
-   ascend_tutorial/ascend_profiling.rst
+   ascend_tutorial/ascend_profiling_zh.rst
    ascend_tutorial/ascend_profiling_en.rst

.. toctree::

recipe/one_step_off_policy/fsdp_workers.py

Lines changed: 47 additions & 4 deletions

@@ -26,8 +26,8 @@
 from verl.single_controller.base import Worker
 from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
 from verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass
-from verl.utils.debug import DistProfiler, DistProfilerExtension, log_gpu_memory_usage
 from verl.utils.device import (
+    get_device_id,
     get_device_name,
     get_nccl_backend,
     get_torch_device,
@@ -38,7 +38,8 @@
 )
 from verl.utils.import_utils import import_external_libs
 from verl.utils.model import get_generation_config, update_model_config
-from verl.utils.profiler import ProfilerConfig
+from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
+from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
 from verl.workers.fsdp_workers import ActorRolloutRefWorker as ARRWorker
 from verl.workers.fsdp_workers import CriticWorker

@@ -231,8 +232,50 @@ def init_model(self):
         self.rollout_sharding_manager = rollout_sharding_manager

     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"), blocking=False)
-    def async_generate_sequences(self, *args, **kwargs):
-        return super().generate_sequences(*args, **kwargs)
+    def async_generate_sequences(self, prompts):
+        # Support all hardwares
+        prompts = prompts.to(get_device_id())
+
+        assert self._is_rollout
+
+        meta_info = {
+            "eos_token_id": self.generation_config.eos_token_id
+            if self.generation_config is not None
+            else self.tokenizer.eos_token_id,
+            "pad_token_id": self.generation_config.pad_token_id
+            if self.generation_config is not None
+            else self.tokenizer.pad_token_id,
+        }
+        prompts.meta_info.update(meta_info)
+        timing_generate = {}
+        with self.rollout_sharding_manager:
+            log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)
+
+            with simple_timer("generate_sequences", timing_generate):
+                output = self.rollout.generate_sequences(prompts=prompts)
+
+            log_gpu_memory_usage("After rollout generation", logger=logger)
+
+        timing_generate.update(self.rollout_sharding_manager.timing)
+        # We calculate the average timing across all ranks
+        # to make sure meta_info["timing"] is the same
+        timing_generate_topk_ratio, timing_generate_min, timing_generate_max = topk_reduce_ratio_min_max(
+            timing_generate["generate_sequences"]
+        )
+        timing_generate = reduce_timing(timing_generate)
+        timing_generate.update(
+            {
+                "generation_timing/max": timing_generate_max,
+                "generation_timing/min": timing_generate_min,
+                "generation_timing/topk_ratio": timing_generate_topk_ratio,
+            }
+        )
+        output.meta_info["timing"] = timing_generate
+        output = output.to("cpu")
+
+        # clear kv cache
+        get_torch_device().empty_cache()
+        return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def set_actor_weights_info(self, weights_info):
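
The bulk of the new `async_generate_sequences` body is timing bookkeeping around `self.rollout.generate_sequences`. As a rough, self-contained sketch of that pattern (an assumption about how `simple_timer` behaves, not verl's actual implementation — the real helpers live in `verl.utils.profiler`):

```python
# Minimal sketch: a context manager that accumulates wall-clock seconds into a
# shared dict, mirroring how the worker above collects "generate_sequences" time
# before reducing it across ranks. Names here are illustrative, not verl's API.
import time
from contextlib import contextmanager


@contextmanager
def simple_timer_sketch(name: str, timing: dict):
    start = time.perf_counter()
    try:
        yield
    finally:
        timing[name] = timing.get(name, 0.0) + time.perf_counter() - start


timing_generate = {}
with simple_timer_sketch("generate_sequences", timing_generate):
    time.sleep(0.1)  # stand-in for rollout.generate_sequences(prompts=prompts)

# In the worker, per-rank dicts like this are then aggregated with
# reduce_timing / topk_reduce_ratio_min_max and attached to output.meta_info["timing"].
print(timing_generate)
```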

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import ray
+
+from verl.single_controller.base.decorator import Dispatch, register
+from verl.single_controller.base.worker import Worker
+from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
+
+
+class TestActor(Worker):
+    # TODO: pass *args and **kwargs is bug prone and not very convincing
+    def __init__(self, x) -> None:
+        super().__init__()
+        self.a = x
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def get(self):
+        return self.a + self.rank
+
+
+class TestHighLevelActor(Worker):
+    def __init__(self, x=None) -> None:
+        super().__init__()
+        self.test_actor = TestActor(x=x)
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def get(self):
+        return self.test_actor.get()
+
+
+def test_nested_worker():
+    ray.init(num_cpus=100)
+
+    # create 4 workers, each hold a GPU
+    resource_pool = RayResourcePool([4], use_gpu=True)
+    class_with_args = RayClassWithInitArgs(cls=ray.remote(TestActor), x=2)
+
+    worker_group = RayWorkerGroup(
+        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic"
+    )
+
+    output = worker_group.get()
+
+    assert output == [2, 3, 4, 5]
+
+    class_with_args = RayClassWithInitArgs(cls=ray.remote(TestHighLevelActor), x=2)
+    high_level_worker_group = RayWorkerGroup(
+        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic_2"
+    )
+
+    output_1 = high_level_worker_group.get()
+
+    assert output_1 == [2, 3, 4, 5]
+
+    ray.shutdown()

verl/workers/config/__init__.py

Lines changed: 3 additions & 2 deletions

@@ -17,6 +17,7 @@
 from .engine import *  # noqa
 from .optimizer import *  # noqa
 from .rollout import *  # noqa
-from . import actor, critic, engine, optimizer, rollout
+from .model import *  # noqa
+from . import actor, critic, engine, optimizer, rollout, model

-__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__ + rollout.__all__
+__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__ + rollout.__all__ + model.__all__
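
With `model` now re-exported from the package `__init__`, the new config class should be importable at the package level. A quick sanity sketch (assuming the verl build from this commit is installed):

```python
# HFModelConfig is aggregated into verl.workers.config.__all__ via `from .model import *`.
from verl.workers.config import HFModelConfig

print(HFModelConfig.__name__)  # -> "HFModelConfig"
```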

verl/workers/config/model.py

Lines changed: 111 additions & 0 deletions

@@ -0,0 +1,111 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
+from omegaconf import MISSING
+from transformers import AutoConfig
+
+from verl.base_config import BaseConfig
+from verl.utils import hf_processor, hf_tokenizer
+from verl.utils.fs import copy_to_local
+from verl.utils.model import get_generation_config, update_model_config
+
+__all__ = ["HFModelConfig"]
+
+
+@dataclass
+class HFModelConfig(BaseConfig):
+    # note that we separate model_path, model_config_path and tokenizer_path in case they are different
+    _mutable_fields = {
+        "hf_config_path",
+        "tokenizer_path",
+        "hf_config",
+        "generation_config",
+        "tokenizer",
+        "processor",
+        "local_path",
+    }
+
+    path: str = MISSING
+    local_path: Optional[str] = None
+    hf_config_path: Optional[str] = None
+    tokenizer_path: Optional[str] = None
+
+    hf_config: Any = None
+    generation_config: Any = None
+    tokenizer: Any = None
+    processor: Any = None
+
+    # whether to use shared memory
+    use_shm: bool = False
+    trust_remote_code: bool = False
+
+    # custom chat template for the model
+    custom_chat_template: Optional[str] = None
+
+    external_lib: Optional[str] = None
+
+    override_config: dict = field(default_factory=dict)
+
+    enable_gradient_checkpointing: bool = True
+    enable_activation_offload: bool = False
+
+    use_remove_padding: bool = False
+
+    # lora related. We may setup a separate config later
+    lora_rank: int = 0
+    lora_alpha: int = 16
+    target_modules: Optional[str] = "all-linear"
+
+    exclude_modules: Optional[str] = None
+    use_liger: bool = False
+
+    use_fused_kernels: bool = False
+    fused_kernel_options: dict = field(default_factory=dict)
+
+    def __post_init__(self):
+        if self.hf_config_path is None:
+            self.hf_config_path = self.path
+        if self.tokenizer_path is None:
+            self.tokenizer_path = self.path
+
+        # construct tokenizer
+        self.local_path = copy_to_local(self.path, use_shm=self.use_shm)
+        self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=self.trust_remote_code)
+        self.processor = hf_processor(self.local_path, trust_remote_code=self.trust_remote_code)
+
+        self.generation_config = get_generation_config(self.hf_config_path, trust_remote_code=self.trust_remote_code)
+
+        # construct hf_config
+        attn_implementation = self.override_config.get("attn_implementation", "flash_attention_2")
+        self.hf_config = AutoConfig.from_pretrained(
+            self.hf_config_path, trust_remote_code=self.trust_remote_code, attn_implementation=attn_implementation
+        )
+
+        override_config_kwargs = {
+            "bos_token_id": self.tokenizer.bos_token_id,
+            "eos_token_id": self.tokenizer.eos_token_id,
+            "pad_token_id": self.tokenizer.pad_token_id,
+        }
+        override_config_kwargs.update(self.override_config)
+        update_model_config(self.hf_config, override_config_kwargs=override_config_kwargs)
+
+        # per model patch
+        if getattr(self.hf_config, "model_type", None) == "kimi_vl":
+            self.hf_config.text_config.topk_method = "greedy"
+
+    def get_processor(self):
+        return self.processor if self.processor is not None else self.tokenizer
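
Because `__post_init__` eagerly resolves the local checkpoint path, tokenizer, processor, generation config, and `hf_config`, constructing the dataclass yields a fully populated model bundle. A hedged usage sketch (the model id is only an example, and running it needs access to the checkpoint assets):

```python
from verl.workers.config import HFModelConfig

# Example checkpoint id; any HF model id or local path should work the same way.
model_config = HFModelConfig(path="Qwen/Qwen2.5-0.5B-Instruct")

# Populated by __post_init__:
print(model_config.local_path)                # local copy of `path` (shared memory if use_shm=True)
print(type(model_config.tokenizer).__name__)  # HF tokenizer built from tokenizer_path
print(model_config.hf_config.model_type)      # transformers AutoConfig with overrides applied
print(model_config.get_processor())           # processor if available, otherwise the tokenizer
```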

verl/workers/config/rollout.py

Lines changed: 3 additions & 1 deletion

@@ -15,6 +15,8 @@
 from dataclasses import dataclass, field
 from typing import Optional

+from omegaconf import MISSING
+
 from verl.base_config import BaseConfig
 from verl.utils.profiler import ProfilerConfig

@@ -77,7 +79,7 @@ class TraceConfig(BaseConfig):
 class RolloutConfig(BaseConfig):
     _mutable_fields = {"max_model_len"}

-    name: Optional[str] = None
+    name: Optional[str] = MISSING
     mode: str = "sync"

     temperature: float = 1.0
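
Changing the default of `name` from `None` to `MISSING` turns a silent fallback into a hard requirement: with a structured OmegaConf config, reading an unset `rollout.name` now raises instead of returning `None`. A small standalone sketch of the OmegaConf behavior this relies on (using a stand-in dataclass, not verl's actual `RolloutConfig`):

```python
from dataclasses import dataclass

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue


@dataclass
class RolloutConfigSketch:
    name: str = MISSING  # must be provided explicitly, e.g. "vllm" or "sglang"
    mode: str = "sync"


cfg = OmegaConf.structured(RolloutConfigSketch)
try:
    _ = cfg.name  # accessing a MISSING field raises
except MissingMandatoryValue:
    print("rollout.name must be set explicitly")

cfg.name = "vllm"
print(OmegaConf.to_yaml(cfg))
```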
