7 changes: 6 additions & 1 deletion verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -1341,7 +1341,10 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int

return req_list

# ==================== server mode public methods ====================

async def chat_completion(self, json_request):
"""OpenAI chat completion API."""
assert self._tp_rank == 0, "only called in tp rank 0"
_input_ids = None
_attention_mask = None
@@ -1419,19 +1422,21 @@ async def chat_completion(self, json_request):
async def generate(
self, prompt_ids: torch.Tensor, sampling_params: dict[str, Any], request_id: str
) -> torch.Tensor:
"""Generate sequence with token-in-token-out."""
request_sampling_params = self.sampling_params.copy()
request_sampling_params.update(sampling_params)
output = await self._handle_engine_generate(prompt_ids, request_sampling_params)
return output["output_ids"]

async def wake_up(self):
"""Load model weights and build kv cache."""
if not self.is_sleep:
return
await self.sharding_manager.wake_up() # pylint: disable=C2801
self.is_sleep = False

# this function is left for uniform train-inference resharding
async def sleep(self):
"""Offload model weights and discard kv cache."""
if self.is_sleep:
return
await self.sharding_manager.sleep()
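Taken together, the sglang_rollout.py changes group chat_completion, generate, wake_up, and sleep under a server-mode public surface. Below is a hedged sketch of how a driver coroutine might exercise these methods; the rollout object, the prompt token IDs, and the sampling values are placeholders for illustration, not code from this PR.

```python
# Hypothetical driver-side sketch (not part of this PR): exercising the
# server-mode methods shown above on an already-constructed rollout object.
import torch


async def run_one_request(rollout):
    await rollout.wake_up()                       # load weights, build kv cache
    prompt_ids = torch.tensor([1, 2, 3])          # placeholder prompt token IDs
    sampling_params = {"temperature": 0.7}        # illustrative sampling override
    output_ids = await rollout.generate(prompt_ids, sampling_params, request_id="req-0")
    await rollout.sleep()                         # offload weights, discard kv cache
    return output_ids
```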
63 changes: 37 additions & 26 deletions verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -26,12 +26,12 @@
- After inference, all the parameters that don't belong to this pp rank are freed.
"""

import asyncio
import getpass
import logging
import os
import pickle
import socket
import threading
from contextlib import contextmanager
from copy import deepcopy
from types import MethodType
@@ -42,6 +42,7 @@
import torch
import torch.distributed
import zmq
import zmq.asyncio
from filelock import FileLock
from omegaconf import DictConfig, OmegaConf
from tensordict import TensorDict
@@ -431,12 +432,12 @@ def _init_zeromq(self) -> str:
else:
ip, port = self._get_free_port()
address = f"tcp://{ip}:{port}"
context = zmq.Context()
context = zmq.asyncio.Context()
self.socket = context.socket(zmq.REP)
self.socket.bind(address)

self.loop_thread = threading.Thread(target=self._loop_forever)
self.loop_thread.start()
loop = asyncio.get_running_loop()
self.zmq_loop_task = loop.create_task(self._loop_forever())

return address
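The _init_zeromq change above swaps the dedicated listener thread for an asyncio task on the already-running event loop. A minimal, self-contained sketch of that pattern follows, assuming pyzmq with its asyncio support; the address, lifetime, and no-op dispatch are placeholders, not code from this PR.

```python
# Standalone sketch of the thread-to-task change above (assumptions: pyzmq
# installed, run as its own script; address and dispatch are placeholders).
import asyncio
import pickle

import zmq
import zmq.asyncio


async def main():
    context = zmq.asyncio.Context()
    socket = context.socket(zmq.REP)
    socket.bind("tcp://127.0.0.1:5555")  # placeholder address

    async def loop_forever():
        while True:
            message = await socket.recv()          # awaits instead of blocking a thread
            method, args, kwargs = pickle.loads(message)
            result = None                          # placeholder for real dispatch
            await socket.send(pickle.dumps(result))

    # The RPC loop now shares the event loop with other coroutines
    # instead of owning its own thread.
    task = asyncio.get_running_loop().create_task(loop_forever())
    try:
        await asyncio.sleep(1)                     # stand-in for the worker's real lifetime
    finally:
        task.cancel()


if __name__ == "__main__":
    asyncio.run(main())
```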

Expand All @@ -447,25 +448,22 @@ def _get_free_port(self):
port = sock.getsockname()[1]
return ip, port

def _loop_forever(self):
async def _loop_forever(self):
while True:
message = self.socket.recv()
message = await self.socket.recv()
method, args, kwargs = pickle.loads(message)
result = self.execute_method(method, *args, **kwargs)
self.socket.send(pickle.dumps(result))
result = await self._execute_method(method, *args, **kwargs)
await self.socket.send(pickle.dumps(result))

def get_zeromq_address(self):
return self.address

def init_worker(self, all_kwargs: list[dict[str, Any]]):
def _init_worker(self, all_kwargs: list[dict[str, Any]]):
"""Initialize worker engine."""
all_kwargs[0]["rank"] = int(os.environ["RANK"])
all_kwargs[0]["local_rank"] = 0 if not ray_noset_visible_devices() else int(os.environ.get("RAY_LOCAL_RANK", 0))
self.vllm_config = all_kwargs[0]["vllm_config"]
self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)
self.inference_engine.init_worker(all_kwargs)

def load_model(self, *args, **kwargs):
def _load_model(self, *args, **kwargs):
self.inference_engine.load_model(*args, **kwargs)

# inference engine is initialized now, update sharding manager
@@ -474,28 +472,41 @@ def load_model(self, *args, **kwargs):

_monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer))

def sleep(self, *args, **kwargs):
async def _execute_method(self, method: str | bytes, *args, **kwargs):
if method == "init_worker":
return self._init_worker(*args, **kwargs)
elif method == "load_model":
return self._load_model(*args, **kwargs)
elif method == "sleep":
return await self.sleep(*args, **kwargs)
elif method == "wake_up":
return await self.wake_up(*args, **kwargs)
else:
return self.inference_engine.execute_method(method, *args, **kwargs)

# ==================== server mode public methods ====================

def get_zeromq_address(self):
return self.address

async def sleep(self, *args, **kwargs):
"""Offload model weights and discard kv cache."""
if self.is_sleep:
return
self.sharding_manager.__exit__(None, None, None)
self.is_sleep = True

def wake_up(self, *args, **kwargs):
async def wake_up(self, *args, **kwargs):
"""Load model weights and build kv cache."""
if not self.is_sleep:
return
self.sharding_manager.__enter__() # pylint: disable=C2801
self.is_sleep = False

def execute_method(self, method: str | bytes, *args, **kwargs):
if method == "init_worker":
return self.init_worker(*args, **kwargs)
elif method == "load_model":
return self.load_model(*args, **kwargs)
elif method == "sleep":
return self.sleep(*args, **kwargs)
elif method == "wake_up":
return self.wake_up(*args, **kwargs)
else:
return self.inference_engine.execute_method(method, *args, **kwargs)
async def generate(self, *args, **kwargs):
"""Generate sequence with token-in-token-out."""
raise NotImplementedError

async def chat_completion(self, json_request):
"""OpenAI chat completion API."""
raise NotImplementedError
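On the vLLM worker side, _loop_forever now awaits pickled (method, args, kwargs) tuples on a zmq.asyncio REP socket and routes them through _execute_method. The sketch below shows what the calling side of that protocol could look like, assuming a simple synchronous REQ client; the address handling and example call are illustrative, not code from this PR.

```python
# Hypothetical client-side sketch of the RPC protocol served by _loop_forever:
# send a pickled (method, args, kwargs) tuple, receive a pickled result.
import pickle

import zmq


def call_worker(address: str, method: str, *args, **kwargs):
    context = zmq.Context()
    socket = context.socket(zmq.REQ)      # REQ pairs with the worker's REP socket
    socket.connect(address)               # address as returned by get_zeromq_address()
    socket.send(pickle.dumps((method, args, kwargs)))
    result = pickle.loads(socket.recv())
    socket.close()
    context.term()
    return result


# e.g. call_worker(addr, "wake_up") or call_worker(addr, "init_worker", all_kwargs)
```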