|
17 | 17 | import socket |
18 | 18 | from abc import ABC, abstractmethod |
19 | 19 | from contextlib import asynccontextmanager |
20 | | -from typing import Any |
| 20 | +from typing import Any, Optional |
21 | 21 |
|
22 | 22 | import fastapi |
23 | 23 | import ray |
@@ -110,3 +110,40 @@ async def wake_up(self): |
110 | 110 | async def sleep(self): |
111 | 111 | """Sleep engine to offload model weights and discard kv cache.""" |
112 | 112 | raise NotImplementedError |
| 113 | + |
| 114 | + |
| 115 | +def async_server_class( |
| 116 | + rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None |
| 117 | +) -> type[AsyncServerBase]: |
| 118 | + """Get async server class. |
| 119 | +
|
| 120 | + Args: |
| 121 | + rollout_backend: str, rollout backend type (alias), should be "vllm" or "sglang". |
| 122 | + rollout_backend_module: Optional[str], import path of the rollout backend. |
| 123 | + rollout_backend_class: Optional[str], class name of the rollout backend. |
| 124 | +
|
| 125 | + Returns: |
| 126 | + Type[AsyncServerBase]: async server class. |
| 127 | + """ |
| 128 | + if rollout_backend_class is None and rollout_backend_module is None: |
| 129 | + # If both are None, use the default backend class |
| 130 | + # Do not change the original import behavior |
| 131 | + # importlib.import_module and from ... import ... have subtle differences in ray |
| 132 | + |
| 133 | + if rollout_backend == "vllm": |
| 134 | + from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer |
| 135 | + |
| 136 | + return AsyncvLLMServer |
| 137 | + elif rollout_backend == "sglang": |
| 138 | + from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSGLangServer |
| 139 | + |
| 140 | + return AsyncSGLangServer |
| 141 | + else: |
| 142 | + raise NotImplementedError(f"rollout backend {rollout_backend} is not supported") |
| 143 | + |
| 144 | + if rollout_backend_module is None or rollout_backend_class is None: |
| 145 | + raise ValueError("rollout_backend_module and rollout_backend_class must be both provided for customization") |
| 146 | + |
| 147 | + from verl.utils.import_utils import load_extern_type |
| 148 | + |
| 149 | + return load_extern_type(rollout_backend_module, rollout_backend_class) |
0 commit comments