
Commit 743f58d

[rollout] chore: Misc changes for extending internal compatibility (volcengine#3701)
### What does this PR do?

* New config field:
  * rollout: `pipeline_model_parallel_size` for internal compatibility
  * ~~legacy_data: `agent_name` for default agent name if not specified in the rldataset~~
* Registry for `RolloutReplica`
* `VERL_USE_EXTERNAL_MODULES` to import desired modules to trigger external registration

### Test

Covered by CI.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent abc54b6 commit 743f58d

File tree

10 files changed: +102 -41 lines


verl/__init__.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@
 
 from .protocol import DataProto
 from .utils.device import is_npu_available
+from .utils.import_utils import import_external_libs
 from .utils.logging_utils import set_basic_config
 
 version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
@@ -35,6 +36,13 @@
 
 __all__ = ["DataProto", "__version__"]
 
+
+modules = os.getenv("VERL_USE_EXTERNAL_MODULES", "")
+if modules:
+    modules = modules.split(",")
+    import_external_libs(modules)
+
+
 if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true":
     if importlib.util.find_spec("modelscope") is None:
         raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`")
```
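Because this hook runs at `import verl` time, external modules have to be requested through the environment before verl is first imported. A minimal usage sketch (the module name `my_verl_plugins` is hypothetical):

```python
import os

# Comma-separated list of importable modules; each is imported by import_external_libs
# so it can run its registration side effects (e.g. registering a rollout replica).
os.environ["VERL_USE_EXTERNAL_MODULES"] = "my_verl_plugins"  # hypothetical module

import verl  # noqa: E402  # triggers import_external_libs(["my_verl_plugins"])
```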

verl/experimental/agent_loop/agent_loop.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -789,6 +789,7 @@ def _initialize_llm_servers(self):
         rollout_world_size = (
             self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
             * self.config.actor_rollout_ref.rollout.data_parallel_size
+            * self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size
         )
         world_size = (
             self.worker_group.world_size
```
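With the new factor, each rollout replica spans TP × DP × PP ranks, and the worker-group world size must be a multiple of that product. A toy check (the numbers are illustrative, not from the PR):

```python
# Toy numbers: one rollout replica spans TP x DP x PP ranks.
tensor_model_parallel_size = 2
data_parallel_size = 2
pipeline_model_parallel_size = 2

rollout_world_size = (
    tensor_model_parallel_size
    * data_parallel_size
    * pipeline_model_parallel_size
)
assert rollout_world_size == 8  # worker-group world size must be a multiple of this
```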

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -190,6 +190,7 @@ actor_rollout_ref:
     tensor_model_parallel_size: 2
     data_parallel_size: 1
     expert_parallel_size: 1
+    pipeline_model_parallel_size: 1
     max_num_batched_tokens: 8192
     max_model_len: null
     max_num_seqs: 1024
```

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -177,6 +177,7 @@ actor_rollout_ref:
     tensor_model_parallel_size: 2
     data_parallel_size: 1
     expert_parallel_size: 1
+    pipeline_model_parallel_size: 1
     max_num_batched_tokens: 8192
     max_model_len: null
     max_num_seqs: 1024
```

verl/trainer/config/rollout/rollout.yaml

Lines changed: 3 additions & 0 deletions
```diff
@@ -55,6 +55,9 @@ data_parallel_size: 1
 # EP size for rollout
 expert_parallel_size: 1
 
+# PP size for rollout.
+pipeline_model_parallel_size: 1
+
 # max number of tokens in a batch
 max_num_batched_tokens: 8192
 
```

verl/utils/memory_utils.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -24,7 +24,8 @@
 
 from verl.utils.device import get_torch_device, is_cuda_available
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
 
 
 def aggressive_empty_cache(force_sync: bool = True, max_retries: int = 3) -> None:
```

verl/workers/config/rollout.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -121,6 +121,7 @@ class RolloutConfig(BaseConfig):
     data_parallel_size: int = 1
     expert_parallel_size: int = 1
     tensor_model_parallel_size: int = 2
+    pipeline_model_parallel_size: int = 1
     max_num_batched_tokens: int = 8192
 
     # TODO: enable train_kwargs
@@ -183,3 +184,9 @@ def __post_init__(self):
         assert self.expert_parallel_size == (self.tensor_model_parallel_size * self.data_parallel_size), (
             "expert_parallel_size must be equal to tensor_model_parallel_size * data_parallel_size"
         )
+
+        if self.pipeline_model_parallel_size > 1:
+            if self.name == "vllm" or self.name == "sglang":
+                raise NotImplementedError(
+                    f"Current rollout {self.name=} not implemented pipeline_model_parallel_size > 1 yet."
+                )
```
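Since the guard only fires for the built-in `vllm` and `sglang` backends, pipeline parallelism remains usable for backends registered under other names. A standalone sketch of the same check (the helper function and the `my_engine` name are illustrative, not verl API):

```python
# Illustrative restatement of the __post_init__ guard, outside of RolloutConfig.
def check_pipeline_parallel(name: str, pipeline_model_parallel_size: int) -> None:
    if pipeline_model_parallel_size > 1 and name in ("vllm", "sglang"):
        raise NotImplementedError(
            f"Current rollout name={name!r} not implemented pipeline_model_parallel_size > 1 yet."
        )


check_pipeline_parallel("vllm", 1)       # fine: PP disabled
check_pipeline_parallel("my_engine", 4)  # fine: not a built-in backend
# check_pipeline_parallel("sglang", 2)   # would raise NotImplementedError
```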

verl/workers/fsdp_workers.py

Lines changed: 10 additions & 5 deletions
```diff
@@ -571,19 +571,24 @@ def _build_rollout(self, trust_remote_code=False):
 
         # 2. build rollout device mesh
         infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
-        dp = self.world_size // infer_tp
-        assert self.world_size % infer_tp == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
+        infer_pp = self.config.rollout.pipeline_model_parallel_size
+        infer_world_size = infer_tp * infer_pp
+        dp = self.world_size // infer_world_size
+        assert self.world_size % infer_world_size == 0, (
+            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
         )
         rollout_device_mesh = init_device_mesh(
-            device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
+            device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
         )
         rollout_name = self.config.rollout.name
 
         if rollout_name == "hf":
             self._register_dispatch_collect_info("rollout", dp_rank=self.rank, is_collect=True)
         else:
-            is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
+            is_collect = (
+                rollout_device_mesh["infer_tp"].get_local_rank() == 0
+                and rollout_device_mesh["infer_pp"].get_local_rank() == 0
+            )
             self._register_dispatch_collect_info(
                 "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
             )
```
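The rollout mesh gains an `infer_pp` dimension, and only the rank whose TP and PP coordinates are both zero within its DP group collects results. A rough sketch of the resulting rank layout using plain arithmetic instead of `init_device_mesh` (the row-major rank assignment is an assumption, and the sizes are illustrative):

```python
# Illustrative: emulate the (dp, infer_tp, infer_pp) mesh coordinates for 8 ranks.
world_size = 8
infer_tp, infer_pp = 2, 2
infer_world_size = infer_tp * infer_pp
assert world_size % infer_world_size == 0
dp = world_size // infer_world_size  # 2 rollout replicas

for rank in range(world_size):
    dp_rank = rank // infer_world_size
    tp_rank = (rank % infer_world_size) // infer_pp
    pp_rank = rank % infer_pp
    is_collect = tp_rank == 0 and pp_rank == 0  # ranks 0 and 4 collect
    print(f"rank={rank} -> dp={dp_rank} tp={tp_rank} pp={pp_rank} collect={is_collect}")
```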

verl/workers/megatron_workers.py

Lines changed: 10 additions & 5 deletions
```diff
@@ -397,15 +397,20 @@ def _build_rollout(self, trust_remote_code=False):
 
         # 2. build rollout device mesh
         infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
-        dp = self.world_size // infer_tp
-        assert self.world_size % infer_tp == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
+        infer_pp = self.config.rollout.pipeline_model_parallel_size
+        infer_world_size = infer_tp * infer_pp
+        dp = self.world_size // infer_world_size
+        assert self.world_size % infer_world_size == 0, (
+            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
         )
         rollout_device_mesh = init_device_mesh(
-            get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
+            get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
         )
 
-        is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
+        is_collect = (
+            rollout_device_mesh["infer_tp"].get_local_rank() == 0
+            and rollout_device_mesh["infer_pp"].get_local_rank() == 0
+        )
         self._register_dispatch_collect_info(
             "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
         )
```

verl/workers/rollout/replica.py

Lines changed: 59 additions & 30 deletions
```diff
@@ -16,7 +16,7 @@
 import os
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Optional
+from typing import Callable, Optional
 
 from pydantic import BaseModel
 from ray.actor import ActorHandle
@@ -85,7 +85,11 @@ def __init__(
         self.config = omega_conf_to_dataclass(config)
         self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
 
-        self.world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size
+        self.world_size = (
+            self.config.tensor_model_parallel_size
+            * self.config.data_parallel_size
+            * self.config.pipeline_model_parallel_size
+        )
         self.gpus_per_node = min(gpus_per_node, self.world_size)
         assert self.world_size % self.gpus_per_node == 0, (
             f"world_size {self.world_size} must be divisible by gpus_per_node {self.gpus_per_node}"
@@ -171,32 +175,57 @@ async def sleep(self):
         await asyncio.gather(*[server.sleep.remote() for server in self.servers])
 
 
+class RolloutReplicaRegistry:
+    """Factory for managing rollout replica implementations."""
+
+    _registry: dict[str, Callable[[], type[RolloutReplica]]] = {}
+
+    @classmethod
+    def register(cls, name: str, loader: Callable[[], type[RolloutReplica]]) -> None:
+        """Register a new rollout replica type."""
+        cls._registry[name] = loader
+
+    @classmethod
+    def get(cls, name: str) -> type[RolloutReplica]:
+        """Get a rollout replica class by name."""
+        if name not in cls._registry:
+            raise ValueError(f"Unknown rollout mode: {name}. Available: {list(cls._registry.keys())}")
+        return cls._registry[name]()
+
+
+# Loader functions for built-in types
+def _load_vllm():
+    from verl.workers.rollout.vllm_rollout.vllm_async_server import vLLMReplica
+
+    return vLLMReplica
+
+
+def _load_sglang():
+    os.environ["SGLANG_USE_CPU_ENGINE"] = "1"
+
+    try:
+        import vllm  # noqa: F401
+    except ImportError:
+        import sys
+        from unittest.mock import Mock
+
+        mock_vllm = Mock()
+        mock_vllm._custom_ops = Mock()
+        mock_vllm._custom_ops.scaled_fp8_quant = Mock()
+        sys.modules["vllm"] = mock_vllm
+        sys.modules["vllm._custom_ops"] = mock_vllm._custom_ops
+
+    from verl.workers.rollout.sglang_rollout.async_sglang_server import SGLangReplica
+
+    del os.environ["SGLANG_USE_CPU_ENGINE"]
+    return SGLangReplica
+
+
+# Register built-in types
+RolloutReplicaRegistry.register("vllm", _load_vllm)
+RolloutReplicaRegistry.register("sglang", _load_sglang)
+
+
+# Original function for backward compatibility
 def get_rollout_replica_class(rollout: str) -> type[RolloutReplica]:
-    if rollout == "vllm":
-        from verl.workers.rollout.vllm_rollout.vllm_async_server import vLLMReplica
-
-        return vLLMReplica
-    elif rollout == "sglang":
-        # NOTE: verl driver is cpu only, avoid sglang fp8 quantization import error.
-        os.environ["SGLANG_USE_CPU_ENGINE"] = "1"
-
-        # TODO: remove this once we bump to sglang>=0.5.1
-        try:
-            import vllm  # noqa: F401
-        except ImportError:
-            import sys
-            from unittest.mock import Mock
-
-            mock_vllm = Mock()
-            mock_vllm._custom_ops = Mock()
-            mock_vllm._custom_ops.scaled_fp8_quant = Mock()
-
-            sys.modules["vllm"] = mock_vllm
-            sys.modules["vllm._custom_ops"] = mock_vllm._custom_ops
-
-        from verl.workers.rollout.sglang_rollout.async_sglang_server import SGLangReplica
-
-        del os.environ["SGLANG_USE_CPU_ENGINE"]
-        return SGLangReplica
-    else:
-        raise ValueError(f"Unknown rollout mode: {rollout}")
+    return RolloutReplicaRegistry.get(rollout)
```
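Downstream code can keep calling `get_rollout_replica_class`, while external backends hook in through the registry. A hedged sketch of how an out-of-tree module (everything named `my_*` here is hypothetical) could register itself and be picked up via `VERL_USE_EXTERNAL_MODULES`:

```python
# Hypothetical external module, e.g. my_verl_plugins/__init__.py.
# Importing it (via VERL_USE_EXTERNAL_MODULES=my_verl_plugins) performs the registration.
from verl.workers.rollout.replica import RolloutReplica, RolloutReplicaRegistry


def _load_my_engine() -> type[RolloutReplica]:
    # Deferred import keeps the heavy backend out of the driver process until needed.
    from my_verl_plugins.my_engine_server import MyEngineReplica  # hypothetical

    return MyEngineReplica


RolloutReplicaRegistry.register("my_engine", _load_my_engine)

# Later, trainer code resolves it exactly like the built-ins:
# replica_cls = get_rollout_replica_class("my_engine")
```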

0 commit comments
