|
16 | 16 | import os |
17 | 17 | from abc import ABC, abstractmethod |
18 | 18 | from enum import Enum |
19 | | -from typing import Optional |
| 19 | +from typing import Callable, Optional |
20 | 20 |
|
21 | 21 | from pydantic import BaseModel |
22 | 22 | from ray.actor import ActorHandle |
@@ -85,7 +85,11 @@ def __init__( |
85 | 85 | self.config = omega_conf_to_dataclass(config) |
86 | 86 | self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) |
87 | 87 |
|
88 | | - self.world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size |
| 88 | + self.world_size = ( |
| 89 | + self.config.tensor_model_parallel_size |
| 90 | + * self.config.data_parallel_size |
| 91 | + * self.config.pipeline_model_parallel_size |
| 92 | + ) |
89 | 93 | self.gpus_per_node = min(gpus_per_node, self.world_size) |
90 | 94 | assert self.world_size % self.gpus_per_node == 0, ( |
91 | 95 | f"world_size {self.world_size} must be divisible by gpus_per_node {self.gpus_per_node}" |
@@ -171,32 +175,57 @@ async def sleep(self): |
171 | 175 | await asyncio.gather(*[server.sleep.remote() for server in self.servers]) |
172 | 176 |
|
173 | 177 |
|
class RolloutReplicaRegistry:
    """Name-keyed registry of rollout replica implementations.

    Each entry maps a rollout name to a zero-argument loader that returns the
    replica class. The loader is invoked on every ``get`` call, not when it is
    registered.
    """

    # Class-level mapping: rollout name -> loader callable.
    _registry: dict[str, Callable[[], type[RolloutReplica]]] = {}

    @classmethod
    def register(cls, name: str, loader: Callable[[], type[RolloutReplica]]) -> None:
        """Store ``loader`` under ``name``, replacing any existing entry."""
        cls._registry.update({name: loader})

    @classmethod
    def get(cls, name: str) -> type[RolloutReplica]:
        """Return the replica class produced by the loader registered as ``name``.

        Raises:
            ValueError: If no loader is registered under ``name``.
        """
        if name in cls._registry:
            loader = cls._registry[name]
            return loader()

        available = list(cls._registry)
        raise ValueError(f"Unknown rollout mode: {name}. Available: {available}")
| 194 | + |
| 195 | + |
| 196 | +# Loader functions for built-in types |
def _load_vllm():
    """Import and return the vLLM rollout replica class.

    The import lives inside the function, so the vLLM server module is only
    imported when this loader is actually called.
    """
    from verl.workers.rollout.vllm_rollout.vllm_async_server import vLLMReplica as replica_cls

    return replica_cls
| 201 | + |
| 202 | + |
def _load_sglang():
    """Import and return the SGLang rollout replica class.

    ``SGLANG_USE_CPU_ENGINE`` is set to ``"1"`` only for the duration of the
    import. Afterwards it is put back to the value it had before the call (or
    removed if it was unset), whether or not the import succeeds.

    Raises:
        ImportError: If the SGLang server module cannot be imported.
    """
    env_key = "SGLANG_USE_CPU_ENGINE"
    previous = os.environ.get(env_key)

    # NOTE: verl driver is cpu only, avoid sglang fp8 quantization import error.
    os.environ[env_key] = "1"
    try:
        # TODO: remove this once we bump to sglang>=0.5.1
        try:
            import vllm  # noqa: F401
        except ImportError:
            import sys
            from unittest.mock import Mock

            # vllm is not installed: register stand-in modules, presumably so
            # the sglang server import below does not fail on `import vllm`.
            mock_vllm = Mock()
            mock_vllm._custom_ops = Mock()
            mock_vllm._custom_ops.scaled_fp8_quant = Mock()
            sys.modules["vllm"] = mock_vllm
            sys.modules["vllm._custom_ops"] = mock_vllm._custom_ops

        from verl.workers.rollout.sglang_rollout.async_sglang_server import SGLangReplica
    finally:
        # Restore the caller's environment even if the import above raised.
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous

    return SGLangReplica
| 222 | + |
| 223 | + |
| 224 | +# Register built-in types |
RolloutReplicaRegistry.register(name="vllm", loader=_load_vllm)
RolloutReplicaRegistry.register(name="sglang", loader=_load_sglang)
| 227 | + |
| 228 | + |
| 229 | +# Kept for backward compatibility; delegates to RolloutReplicaRegistry. |
def get_rollout_replica_class(rollout: str) -> type[RolloutReplica]:
    """Return the rollout replica class registered under ``rollout``.

    Thin wrapper that delegates the lookup to ``RolloutReplicaRegistry.get``.
    """
    replica_cls = RolloutReplicaRegistry.get(rollout)
    return replica_cls
0 commit comments