diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index 08a652762f4..38b7673e35b 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -1341,7 +1341,10 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in
 
         return req_list
 
+    # ==================== server mode public methods ====================
+
     async def chat_completion(self, json_request):
+        """OpenAI chat completion API."""
         assert self._tp_rank == 0, "only called in tp rank 0"
         _input_ids = None
         _attention_mask = None
@@ -1419,19 +1422,21 @@ async def chat_completion(self, json_request):
     async def generate(
         self, prompt_ids: torch.Tensor, sampling_params: dict[str, Any], request_id: str
     ) -> torch.Tensor:
+        """Generate sequence with token-in-token-out."""
        request_sampling_params = self.sampling_params.copy()
         request_sampling_params.update(sampling_params)
         output = await self._handle_engine_generate(prompt_ids, request_sampling_params)
         return output["output_ids"]
 
     async def wake_up(self):
+        """Load model weights and build kv cache."""
         if not self.is_sleep:
             return
         await self.sharding_manager.wake_up()  # pylint: disable=C2801
         self.is_sleep = False
 
-    # this function is left for uniform train-inference resharding
     async def sleep(self):
+        """Offload model weights and discard kv cache."""
         if self.is_sleep:
             return
         await self.sharding_manager.sleep()
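The sglang rollout now exposes `wake_up` / `generate` / `sleep` / `chat_completion` as its server-mode surface. The sketch below shows one way a caller might drive those coroutines; the `rollout` object, the `run_one_request` helper, the prompt tensor, and the sampling overrides are illustrative assumptions, not part of this diff.

```python
import torch


async def run_one_request(rollout, prompt_ids: torch.Tensor) -> torch.Tensor:
    # Reload weights and rebuild the kv cache before serving (no-op if already awake).
    await rollout.wake_up()
    # Token-in-token-out: pass prompt token ids, get generated token ids back.
    output_ids = await rollout.generate(
        prompt_ids,
        sampling_params={"temperature": 0.7},  # illustrative per-request override
        request_id="req-0",
    )
    # Offload weights and drop the kv cache so training can reclaim GPU memory.
    await rollout.sleep()
    return output_ids
```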
""" +import asyncio import getpass import logging import os import pickle import socket -import threading from contextlib import contextmanager from copy import deepcopy from types import MethodType @@ -42,6 +42,7 @@ import torch import torch.distributed import zmq +import zmq.asyncio from filelock import FileLock from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict @@ -431,12 +432,12 @@ def _init_zeromq(self) -> str: else: ip, port = self._get_free_port() address = f"tcp://{ip}:{port}" - context = zmq.Context() + context = zmq.asyncio.Context() self.socket = context.socket(zmq.REP) self.socket.bind(address) - self.loop_thread = threading.Thread(target=self._loop_forever) - self.loop_thread.start() + loop = asyncio.get_running_loop() + self.zmq_loop_task = loop.create_task(self._loop_forever()) return address @@ -447,17 +448,14 @@ def _get_free_port(self): port = sock.getsockname()[1] return ip, port - def _loop_forever(self): + async def _loop_forever(self): while True: - message = self.socket.recv() + message = await self.socket.recv() method, args, kwargs = pickle.loads(message) - result = self.execute_method(method, *args, **kwargs) - self.socket.send(pickle.dumps(result)) + result = await self._execute_method(method, *args, **kwargs) + await self.socket.send(pickle.dumps(result)) - def get_zeromq_address(self): - return self.address - - def init_worker(self, all_kwargs: list[dict[str, Any]]): + def _init_worker(self, all_kwargs: list[dict[str, Any]]): """Initialize worker engine.""" all_kwargs[0]["rank"] = int(os.environ["RANK"]) all_kwargs[0]["local_rank"] = 0 if not ray_noset_visible_devices() else int(os.environ.get("RAY_LOCAL_RANK", 0)) @@ -465,7 +463,7 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]): self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config) self.inference_engine.init_worker(all_kwargs) - def load_model(self, *args, **kwargs): + def _load_model(self, *args, **kwargs): self.inference_engine.load_model(*args, **kwargs) # inference engine is initialized now, update sharding manager @@ -474,28 +472,41 @@ def load_model(self, *args, **kwargs): _monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer)) - def sleep(self, *args, **kwargs): + async def _execute_method(self, method: str | bytes, *args, **kwargs): + if method == "init_worker": + return self._init_worker(*args, **kwargs) + elif method == "load_model": + return self._load_model(*args, **kwargs) + elif method == "sleep": + return await self.sleep(*args, **kwargs) + elif method == "wake_up": + return await self.wake_up(*args, **kwargs) + else: + return self.inference_engine.execute_method(method, *args, **kwargs) + + # ==================== server mode public methods ==================== + + def get_zeromq_address(self): + return self.address + + async def sleep(self, *args, **kwargs): """Offload model weights and discard kv cache.""" if self.is_sleep: return self.sharding_manager.__exit__(None, None, None) self.is_sleep = True - def wake_up(self, *args, **kwargs): + async def wake_up(self, *args, **kwargs): """Load model weights and build kv cache.""" if not self.is_sleep: return self.sharding_manager.__enter__() # pylint: disable=C2801 self.is_sleep = False - def execute_method(self, method: str | bytes, *args, **kwargs): - if method == "init_worker": - return self.init_worker(*args, **kwargs) - elif method == "load_model": - return self.load_model(*args, **kwargs) - elif method == "sleep": - return self.sleep(*args, 
@@ -474,28 +472,41 @@ def load_model(self, *args, **kwargs):
 
         _monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer))
 
-    def sleep(self, *args, **kwargs):
+    async def _execute_method(self, method: str | bytes, *args, **kwargs):
+        if method == "init_worker":
+            return self._init_worker(*args, **kwargs)
+        elif method == "load_model":
+            return self._load_model(*args, **kwargs)
+        elif method == "sleep":
+            return await self.sleep(*args, **kwargs)
+        elif method == "wake_up":
+            return await self.wake_up(*args, **kwargs)
+        else:
+            return self.inference_engine.execute_method(method, *args, **kwargs)
+
+    # ==================== server mode public methods ====================
+
+    def get_zeromq_address(self):
+        return self.address
+
+    async def sleep(self, *args, **kwargs):
         """Offload model weights and discard kv cache."""
         if self.is_sleep:
             return
         self.sharding_manager.__exit__(None, None, None)
         self.is_sleep = True
 
-    def wake_up(self, *args, **kwargs):
+    async def wake_up(self, *args, **kwargs):
         """Load model weights and build kv cache."""
         if not self.is_sleep:
             return
         self.sharding_manager.__enter__()  # pylint: disable=C2801
         self.is_sleep = False
 
-    def execute_method(self, method: str | bytes, *args, **kwargs):
-        if method == "init_worker":
-            return self.init_worker(*args, **kwargs)
-        elif method == "load_model":
-            return self.load_model(*args, **kwargs)
-        elif method == "sleep":
-            return self.sleep(*args, **kwargs)
-        elif method == "wake_up":
-            return self.wake_up(*args, **kwargs)
-        else:
-            return self.inference_engine.execute_method(method, *args, **kwargs)
+    async def generate(self, *args, **kwargs):
+        """Generate sequence with token-in-token-out."""
+        raise NotImplementedError
+
+    async def chat_completion(self, json_request):
+        """OpenAI chat completion API."""
+        raise NotImplementedError
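For reference, the client side of the protocol that `_loop_forever` serves looks roughly like the sketch below; `call_worker` and `addr` are hypothetical names, and how the caller obtains the address returned by `get_zeromq_address()` (for example through whatever control plane launches the worker) is outside this diff.

```python
import pickle

import zmq


def call_worker(address: str, method: str, *args, **kwargs):
    """Send one pickled (method, args, kwargs) request and return the unpickled reply."""
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.connect(address)
    # The worker unpacks exactly this 3-tuple: (method, args, kwargs).
    socket.send(pickle.dumps((method, args, kwargs)))
    result = pickle.loads(socket.recv())
    socket.close()
    context.term()
    return result


# For example: call_worker(addr, "wake_up") before a rollout step and
# call_worker(addr, "sleep") afterwards.
```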