Skip to content

Commit b4006a1

Browse files
committed
fix test
1 parent 0ad4f4f commit b4006a1

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

verl/workers/rollout/async_server.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import socket
1818
from abc import ABC, abstractmethod
1919
from contextlib import asynccontextmanager
20-
from typing import Any
20+
from typing import Any, Optional
2121

2222
import fastapi
2323
import ray
@@ -110,3 +110,40 @@ async def wake_up(self):
110110
async def sleep(self):
111111
"""Sleep engine to offload model weights and discard kv cache."""
112112
raise NotImplementedError
113+
114+
115+
def async_server_class(
116+
rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None
117+
) -> type[AsyncServerBase]:
118+
"""Get async server class.
119+
120+
Args:
121+
rollout_backend: str, rollout backend type (alias), should be "vllm" or "sglang".
122+
rollout_backend_module: Optional[str], import path of the rollout backend.
123+
rollout_backend_class: Optional[str], class name of the rollout backend.
124+
125+
Returns:
126+
Type[AsyncServerBase]: async server class.
127+
"""
128+
if rollout_backend_class is None and rollout_backend_module is None:
129+
# If both are None, use the default backend class
130+
# Do not change the original import behavior
131+
# importlib.import_module and from ... import ... have subtle differences in ray
132+
133+
if rollout_backend == "vllm":
134+
from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer
135+
136+
return AsyncvLLMServer
137+
elif rollout_backend == "sglang":
138+
from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSGLangServer
139+
140+
return AsyncSGLangServer
141+
else:
142+
raise NotImplementedError(f"rollout backend {rollout_backend} is not supported")
143+
144+
if rollout_backend_module is None or rollout_backend_class is None:
145+
raise ValueError("rollout_backend_module and rollout_backend_class must be both provided for customization")
146+
147+
from verl.utils.import_utils import load_extern_type
148+
149+
return load_extern_type(rollout_backend_module, rollout_backend_class)

0 commit comments

Comments
 (0)