Skip to content

Commit 15f71a6

Browse files
authored
[rollout,vllm] feat: unify vllm and sglang method to async (volcengine#2982)
### What does this PR do? Change vLLM methods to async to unify with SGLang.
1 parent f46f5bb commit 15f71a6

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1341,7 +1341,10 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in
13411341

13421342
return req_list
13431343

1344+
# ==================== server mode public methods ====================
1345+
13441346
async def chat_completion(self, json_request):
1347+
"""OpenAI chat completion API."""
13451348
assert self._tp_rank == 0, "only called in tp rank 0"
13461349
_input_ids = None
13471350
_attention_mask = None
@@ -1419,19 +1422,21 @@ async def chat_completion(self, json_request):
14191422
async def generate(
14201423
self, prompt_ids: torch.Tensor, sampling_params: dict[str, Any], request_id: str
14211424
) -> torch.Tensor:
1425+
"""Generate sequence with token-in-token-out."""
14221426
request_sampling_params = self.sampling_params.copy()
14231427
request_sampling_params.update(sampling_params)
14241428
output = await self._handle_engine_generate(prompt_ids, request_sampling_params)
14251429
return output["output_ids"]
14261430

14271431
async def wake_up(self):
1432+
"""Load model weights and build kv cache."""
14281433
if not self.is_sleep:
14291434
return
14301435
await self.sharding_manager.wake_up() # pylint: disable=C2801
14311436
self.is_sleep = False
14321437

1433-
# this function is left for uniform train-inference resharding
14341438
async def sleep(self):
1439+
"""Offload model weights and discard kv cache."""
14351440
if self.is_sleep:
14361441
return
14371442
await self.sharding_manager.sleep()

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@
2626
- After inference, all the parameters that don't belong to this pp rank are freed.
2727
"""
2828

29+
import asyncio
2930
import getpass
3031
import logging
3132
import os
3233
import pickle
3334
import socket
34-
import threading
3535
from contextlib import contextmanager
3636
from copy import deepcopy
3737
from types import MethodType
@@ -42,6 +42,7 @@
4242
import torch
4343
import torch.distributed
4444
import zmq
45+
import zmq.asyncio
4546
from filelock import FileLock
4647
from omegaconf import DictConfig, ListConfig, OmegaConf
4748
from tensordict import TensorDict
@@ -446,12 +447,12 @@ def _init_zeromq(self) -> str:
446447
else:
447448
ip, port = self._get_free_port()
448449
address = f"tcp://{ip}:{port}"
449-
context = zmq.Context()
450+
context = zmq.asyncio.Context()
450451
self.socket = context.socket(zmq.REP)
451452
self.socket.bind(address)
452453

453-
self.loop_thread = threading.Thread(target=self._loop_forever)
454-
self.loop_thread.start()
454+
loop = asyncio.get_running_loop()
455+
self.zmq_loop_task = loop.create_task(self._loop_forever())
455456

456457
return address
457458

@@ -462,25 +463,22 @@ def _get_free_port(self):
462463
port = sock.getsockname()[1]
463464
return ip, port
464465

465-
def _loop_forever(self):
466+
async def _loop_forever(self):
466467
while True:
467-
message = self.socket.recv()
468+
message = await self.socket.recv()
468469
method, args, kwargs = pickle.loads(message)
469-
result = self.execute_method(method, *args, **kwargs)
470-
self.socket.send(pickle.dumps(result))
470+
result = await self._execute_method(method, *args, **kwargs)
471+
await self.socket.send(pickle.dumps(result))
471472

472-
def get_zeromq_address(self):
473-
return self.address
474-
475-
def init_worker(self, all_kwargs: list[dict[str, Any]]):
473+
def _init_worker(self, all_kwargs: list[dict[str, Any]]):
476474
"""Initialize worker engine."""
477475
all_kwargs[0]["rank"] = int(os.environ["RANK"])
478476
all_kwargs[0]["local_rank"] = 0 if not ray_noset_visible_devices() else int(os.environ.get("RAY_LOCAL_RANK", 0))
479477
self.vllm_config = all_kwargs[0]["vllm_config"]
480478
self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)
481479
self.inference_engine.init_worker(all_kwargs)
482480

483-
def load_model(self, *args, **kwargs):
481+
def _load_model(self, *args, **kwargs):
484482
self.inference_engine.load_model(*args, **kwargs)
485483

486484
# inference engine is initialized now, update sharding manager
@@ -489,28 +487,41 @@ def load_model(self, *args, **kwargs):
489487

490488
_monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer))
491489

492-
def sleep(self, *args, **kwargs):
490+
async def _execute_method(self, method: str | bytes, *args, **kwargs):
491+
if method == "init_worker":
492+
return self._init_worker(*args, **kwargs)
493+
elif method == "load_model":
494+
return self._load_model(*args, **kwargs)
495+
elif method == "sleep":
496+
return await self.sleep(*args, **kwargs)
497+
elif method == "wake_up":
498+
return await self.wake_up(*args, **kwargs)
499+
else:
500+
return self.inference_engine.execute_method(method, *args, **kwargs)
501+
502+
# ==================== server mode public methods ====================
503+
504+
def get_zeromq_address(self):
505+
return self.address
506+
507+
async def sleep(self, *args, **kwargs):
493508
"""Offload model weights and discard kv cache."""
494509
if self.is_sleep:
495510
return
496511
self.sharding_manager.__exit__(None, None, None)
497512
self.is_sleep = True
498513

499-
def wake_up(self, *args, **kwargs):
514+
async def wake_up(self, *args, **kwargs):
500515
"""Load model weights and build kv cache."""
501516
if not self.is_sleep:
502517
return
503518
self.sharding_manager.__enter__() # pylint: disable=C2801
504519
self.is_sleep = False
505520

506-
def execute_method(self, method: str | bytes, *args, **kwargs):
507-
if method == "init_worker":
508-
return self.init_worker(*args, **kwargs)
509-
elif method == "load_model":
510-
return self.load_model(*args, **kwargs)
511-
elif method == "sleep":
512-
return self.sleep(*args, **kwargs)
513-
elif method == "wake_up":
514-
return self.wake_up(*args, **kwargs)
515-
else:
516-
return self.inference_engine.execute_method(method, *args, **kwargs)
521+
async def generate(self, *args, **kwargs):
522+
"""Generate sequence with token-in-token-out."""
523+
raise NotImplementedError
524+
525+
async def chat_completion(self, json_request):
526+
"""OpenAI chat completion API."""
527+
raise NotImplementedError

0 commit comments

Comments (0)