Skip to content

Commit 4a2d266

Browse files
authored
[rollout] feat: deprecate all rollout sharding manager (#3285)
### What does this PR do? Deprecate all rollout sharding managers and replace them with `trainer_mode` and `rollout_mode` in the hybrid worker.
1 parent d16b1b5 commit 4a2d266

29 files changed

+800
-770
lines changed

.github/workflows/e2e_ppo_trainer_megatron_sglang.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ jobs:
284284
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek)
285285
run: |
286286
ray stop --force
287+
MEGATRON_CI_DISABLE_EXPANDABLE_SEGMENTS=1 \
287288
ADV_ESTIMATOR=grpo USE_DUMMY_MODEL=True DUMMY_MODEL_CONFIG_PATH=tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json \
288289
PPO_MAX_TOKEN_LEN=512 FWD_MAX_TOKEN_LEN=512 \
289290
MAX_PROMPT_LENGTH=256 MAX_RESPONSE_LENGTH=256 \

.github/workflows/sgl.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# - `special_sanity`: a suite of quick sanity tests
1313
# - `special_standalone`: a set of tests that are designed to run in dedicated environments
1414

15-
# Accelerators for tests
15+
# Accelerators for tests
1616
# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`.
1717
# - Test scripts with the `on_cpu.py` name suffix are tested on CPU resources in a Linux environment.
1818

@@ -79,7 +79,7 @@ permissions:
7979
jobs:
8080
sgl:
8181
runs-on: [L20x8]
82-
timeout-minutes: 20 # Increase this timeout value as needed
82+
timeout-minutes: 35 # Increase this timeout value as needed
8383
env:
8484
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
8585
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
@@ -142,4 +142,4 @@ jobs:
142142
- name: Test the latest SGLang Rollout async with multimodal delta
143143
run: |
144144
cd tests/workers/rollout
145-
pytest -s test_sglang_async_rollout_multimodal_delta.py
145+
pytest -s test_sglang_async_rollout_multimodal_delta.py

recipe/one_step_off_policy/fsdp_workers.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@
4040
from verl.utils.model import get_generation_config, update_model_config
4141
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
4242
from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
43+
from verl.workers.config import HFModelConfig, RolloutConfig
4344
from verl.workers.fsdp_workers import ActorRolloutRefWorker as ARRWorker
4445
from verl.workers.fsdp_workers import CriticWorker
46+
from verl.workers.rollout import get_rollout_class
4547

4648
logger = logging.getLogger(__file__)
4749
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -204,20 +206,12 @@ def init_model(self):
204206
rollout_name = self.config.rollout.name
205207
assert rollout_name == "vllm"
206208

207-
from verl.workers.rollout.vllm_rollout import vLLMRollout
209+
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
210+
model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)
208211

209212
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
210-
211-
from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
212-
213-
vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
214-
rollout = vllm_rollout_cls(
215-
model_path=local_path,
216-
config=self.config.rollout,
217-
tokenizer=self.tokenizer,
218-
model_hf_config=actor_model_config,
219-
device_mesh=rollout_device_mesh,
220-
trust_remote_code=trust_remote_code,
213+
rollout = get_rollout_class(rollout_config.name, rollout_config.mode)(
214+
config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
221215
)
222216
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
223217
from .vllm_sharding_manager import VLLMShardingManager

recipe/one_step_off_policy/megatron_workers.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,24 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import copy
1617
import logging
1718
import os
1819

1920
import torch
2021
import torch.distributed
21-
from omegaconf import DictConfig, OmegaConf
22+
from omegaconf import DictConfig, OmegaConf, open_dict
2223

2324
from verl.single_controller.base.decorator import Dispatch, register
25+
from verl.utils.config import omega_conf_to_dataclass
2426
from verl.utils.debug import (
2527
log_gpu_memory_usage,
2628
)
2729
from verl.utils.device import get_device_name, get_torch_device
28-
from verl.utils.fs import copy_to_local
30+
from verl.workers.config import HFModelConfig, RolloutConfig
2931
from verl.workers.megatron_workers import ActorRolloutRefWorker as ARRWorker
3032
from verl.workers.megatron_workers import CriticWorker, RewardModelWorker
33+
from verl.workers.rollout import get_rollout_class
3134

3235
logger = logging.getLogger(__file__)
3336
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -145,8 +148,6 @@ def init_model(self):
145148
assert self.config.rollout.name == "vllm"
146149
assert self.config.rollout.mode == "sync"
147150

148-
from verl.workers.rollout.vllm_rollout import vLLMRollout
149-
150151
from .vllm_sharding_manager import VLLMShardingManager
151152

152153
# NOTE(sgm): If the QKV and gate_up projection layers are concatenated together in actor,
@@ -162,17 +163,17 @@ def init_model(self):
162163
)
163164
log_gpu_memory_usage("Before building vllm rollout", logger=None)
164165

165-
local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False))
166-
from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
167-
168-
vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
169-
rollout = vllm_rollout_cls(
170-
model_path=local_path,
171-
config=self.config.rollout,
172-
tokenizer=self.tokenizer,
173-
model_hf_config=self.hf_config,
174-
device_mesh=rollout_device_mesh,
175-
trust_remote_code=trust_remote_code,
166+
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
167+
# (vermouth1992). self.config.model in megatron differs from that of fsdp in the override_config.
168+
# To work around this, we deepcopy self.config.model and make them compatible
169+
omega_model_config = copy.deepcopy(self.config.model)
170+
with open_dict(omega_model_config):
171+
override_config = omega_model_config.override_config.pop("model_config")
172+
omega_model_config.override_config = override_config
173+
174+
model_config: HFModelConfig = omega_conf_to_dataclass(omega_model_config, dataclass_type=HFModelConfig)
175+
rollout = get_rollout_class(rollout_config.name, rollout_config.mode)(
176+
config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
176177
)
177178
log_gpu_memory_usage("After building vllm rollout", logger=logger)
178179

recipe/sppo/sppo_worker.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,7 @@ def init_model(self):
9191
)
9292

9393
if self._is_rollout:
94-
self.rollout, self.rollout_sharding_manager = self._build_rollout(
95-
trust_remote_code=self.config.model.get("trust_remote_code", False)
96-
)
94+
self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False))
9795

9896
if self._is_ref:
9997
self.ref_module_fsdp = self._build_model_optimizer(

tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
import torch.distributed
2020
import torch.distributed as dist
2121
from omegaconf import OmegaConf
22-
from transformers import AutoConfig, AutoTokenizer
22+
from transformers import AutoTokenizer
2323

2424
from verl import DataProto
25+
from verl.utils.config import omega_conf_to_dataclass
2526
from verl.utils.distributed import initialize_global_process_group
2627
from verl.utils.model import compute_position_id_with_mask
28+
from verl.workers.config import HFModelConfig, RolloutConfig
2729
from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import vLLMRollout
2830

2931

@@ -36,7 +38,7 @@ def test_vllm_rollout_with_yarn_position_embeddings():
3638
model_path = os.path.expanduser("~/models/OldKingMeister/Qwen2.5-1.5B-Instruct-YaRN")
3739
config = OmegaConf.create(
3840
{
39-
"model_path": model_path,
41+
"name": "vllm",
4042
"prompt_length": 35000,
4143
"response_length": 512,
4244
"dtype": "bfloat16",
@@ -56,26 +58,27 @@ def test_vllm_rollout_with_yarn_position_embeddings():
5658
"do_sample": False,
5759
},
5860
"tensor_model_parallel_size": 4,
59-
"trust_remote_code": True,
6061
"calculate_log_probs": False,
6162
"do_sample": False,
6263
"temperature": 0.0,
6364
"max_num_batched_tokens": 35000 + 512,
6465
}
6566
)
6667

67-
tokenizer = AutoTokenizer.from_pretrained(config.model_path, trust_remote_code=True, padding_side="left")
68+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left")
6869
tokenizer.pad_token = tokenizer.eos_token
69-
model_hf_config = AutoConfig.from_pretrained(config.model_path)
7070

7171
# do_sample=False with temperature=0 for deterministic output
7272
input_dataproto = prepare_input_dataproto(tokenizer, config, validate=True, do_sample=False)
7373

74+
rollout_config: RolloutConfig = omega_conf_to_dataclass(config, dataclass_type=RolloutConfig)
75+
model_config = HFModelConfig(path=model_path)
76+
model_config.tokenizer.pad_token = tokenizer.eos_token
77+
7478
vllm_rollout = vLLMRollout(
75-
model_path=config.model_path,
76-
config=config,
77-
tokenizer=tokenizer,
78-
model_hf_config=model_hf_config,
79+
config=rollout_config,
80+
model_config=model_config,
81+
device_mesh=None,
7982
)
8083
# rollout
8184
rollout_response = vllm_rollout.generate_sequences(

tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from verl.tools.mcp_search_tool import MCPSearchTool
3131
from verl.tools.schemas import ToolResponse
3232
from verl.tools.utils.mcp_clients.McpClientManager import MCPClientManager
33+
from verl.utils.config import omega_conf_to_dataclass
34+
from verl.workers.config import HFModelConfig, RolloutConfig
3335
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message
3436
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
3537

@@ -115,18 +117,18 @@ def get_search_messages():
115117

116118

117119
class TestRolloutWithMCPSearchTools:
120+
local_model_path = "Qwen/Qwen2.5-0.5B"
121+
118122
@pytest.fixture
119123
def qwen_tokenizer(self):
120-
local_model_path = "Qwen/Qwen2.5-0.5B"
121-
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left")
124+
tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side="left")
122125
tokenizer.pad_token = tokenizer.eos_token
123126
return tokenizer
124127

125128
# we only need this for tokenizer
126129
@pytest.fixture
127130
def qwen_model_config(self):
128-
local_model_path = "Qwen/Qwen2.5-0.5B"
129-
config = AutoConfig.from_pretrained(local_model_path)
131+
config = AutoConfig.from_pretrained(self.local_model_path)
130132
return config
131133

132134
@pytest.fixture
@@ -269,11 +271,12 @@ def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config)
269271
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
270272
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
271273
):
274+
rollout_config: RolloutConfig = omega_conf_to_dataclass(search_rollout_config, dataclass_type=RolloutConfig)
275+
model_config = HFModelConfig(path=self.local_model_path)
272276
rollout = SGLangRollout(
273-
actor_module="",
274-
config=search_rollout_config,
275-
processing_class=qwen_tokenizer,
276-
model_hf_config=qwen_model_config,
277+
config=rollout_config,
278+
model_config=model_config,
279+
device_mesh=None,
277280
)
278281
rollout.sampling_params = {
279282
"n": 1,

tests/workers/rollout/test_sglang_async_rollout_search_tools.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
ToolResponse,
3535
)
3636
from verl.tools.search_tool import SearchTool
37+
from verl.utils.config import omega_conf_to_dataclass
38+
from verl.workers.config import HFModelConfig, RolloutConfig
3739
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message
3840
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
3941

@@ -87,18 +89,18 @@ def get_search_messages():
8789

8890

8991
class TestRolloutWithSearchTools:
92+
local_model_path = "Qwen/Qwen2.5-0.5B"
93+
9094
@pytest.fixture
9195
def qwen_tokenizer(self):
92-
local_model_path = "Qwen/Qwen2.5-0.5B"
93-
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left")
96+
tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side="left")
9497
tokenizer.pad_token = tokenizer.eos_token
9598
return tokenizer
9699

97100
# we only need this for tokenizer
98101
@pytest.fixture
99102
def qwen_model_config(self):
100-
local_model_path = "Qwen/Qwen2.5-0.5B"
101-
config = AutoConfig.from_pretrained(local_model_path)
103+
config = AutoConfig.from_pretrained(self.local_model_path)
102104
return config
103105

104106
@pytest.fixture
@@ -172,11 +174,12 @@ def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config)
172174
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
173175
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
174176
):
177+
rollout_config: RolloutConfig = omega_conf_to_dataclass(search_rollout_config, dataclass_type=RolloutConfig)
178+
model_config = HFModelConfig(path=self.local_model_path)
175179
rollout = SGLangRollout(
176-
actor_module="",
177-
config=search_rollout_config,
178-
processing_class=qwen_tokenizer,
179-
model_hf_config=qwen_model_config,
180+
config=rollout_config,
181+
model_config=model_config,
182+
device_mesh=None,
180183
)
181184
rollout.sampling_params = {
182185
"n": 1,
@@ -193,11 +196,12 @@ def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config)
193196
def test_tools_registration(
194197
self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config
195198
):
199+
rollout_config: RolloutConfig = omega_conf_to_dataclass(search_rollout_config, dataclass_type=RolloutConfig)
200+
model_config = HFModelConfig(path=self.local_model_path)
196201
rollout = SGLangRollout(
197-
actor_module="",
198-
config=search_rollout_config,
199-
processing_class=qwen_tokenizer,
200-
model_hf_config=qwen_model_config,
202+
config=rollout_config,
203+
model_config=model_config,
204+
device_mesh=None,
201205
)
202206
assert len(rollout._tool_schemas) == 1
203207
assert "search" in rollout._tool_map.keys()
@@ -220,11 +224,12 @@ def test_rollout_req_creation(
220224
qwen_model_config,
221225
search_data_proto,
222226
):
227+
rollout_config: RolloutConfig = omega_conf_to_dataclass(search_rollout_config, dataclass_type=RolloutConfig)
228+
model_config = HFModelConfig(path=self.local_model_path)
223229
rollout = SGLangRollout(
224-
actor_module="",
225-
config=search_rollout_config,
226-
processing_class=qwen_tokenizer,
227-
model_hf_config=qwen_model_config,
230+
config=rollout_config,
231+
model_config=model_config,
232+
device_mesh=None,
228233
)
229234
req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)
230235
assert len(req_list) == 1

tests/workers/rollout/test_sglang_async_rollout_sf_tools.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
)
4242
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message
4343
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
44+
from verl.utils.config import omega_conf_to_dataclass
45+
from verl.workers.config import HFModelConfig, RolloutConfig
4446

4547
sandbox_url = ""
4648

@@ -148,18 +150,18 @@ def wrapper(*args, **kwargs):
148150

149151

150152
class TestRolloutWithTools:
153+
local_model_path = "Qwen/Qwen2.5-0.5B"
154+
151155
@pytest.fixture
152156
def qwen_tokenizer(self):
153-
local_model_path = "Qwen/Qwen2.5-0.5B"
154-
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left")
157+
tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side="left")
155158
tokenizer.pad_token = tokenizer.eos_token
156159
return tokenizer
157160

158161
# we only need this for tokenizer
159162
@pytest.fixture
160163
def qwen_model_config(self):
161-
local_model_path = "Qwen/Qwen2.5-0.5B"
162-
config = AutoConfig.from_pretrained(local_model_path)
164+
config = AutoConfig.from_pretrained(self.local_model_path)
163165
return config
164166

165167
@pytest.fixture
@@ -227,11 +229,12 @@ def mock_rollout(self, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model
227229
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
228230
SGLangRollout, "_init_inference_engine", return_value=None
229231
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
232+
rollout_config: RolloutConfig = omega_conf_to_dataclass(sandbox_fusion_rollout_config, dataclass_type=RolloutConfig)
233+
model_config = HFModelConfig(path=self.local_model_path)
230234
rollout = SGLangRollout(
231-
actor_module="",
232-
config=sandbox_fusion_rollout_config,
233-
processing_class=qwen_tokenizer,
234-
model_hf_config=qwen_model_config,
235+
config=rollout_config,
236+
model_config=model_config,
237+
device_mesh=None,
235238
)
236239
# set default sampling_params
237240
rollout.sampling_params = {

0 commit comments

Comments
 (0)