2626- After inference, all the parameters that don't belong to this pp rank are freed.
2727"""
2828
29+ import asyncio
2930import getpass
3031import logging
3132import os
3233import pickle
3334import socket
34- import threading
3535from contextlib import contextmanager
3636from copy import deepcopy
3737from types import MethodType
4242import torch
4343import torch .distributed
4444import zmq
45+ import zmq .asyncio
4546from filelock import FileLock
4647from omegaconf import DictConfig , ListConfig , OmegaConf
4748from tensordict import TensorDict
@@ -446,12 +447,12 @@ def _init_zeromq(self) -> str:
446447 else :
447448 ip , port = self ._get_free_port ()
448449 address = f"tcp://{ ip } :{ port } "
449- context = zmq .Context ()
450+ context = zmq .asyncio . Context ()
450451 self .socket = context .socket (zmq .REP )
451452 self .socket .bind (address )
452453
453- self . loop_thread = threading . Thread ( target = self . _loop_forever )
454- self .loop_thread . start ( )
454+ loop = asyncio . get_running_loop ( )
455+ self .zmq_loop_task = loop . create_task ( self . _loop_forever () )
455456
456457 return address
457458
@@ -462,25 +463,22 @@ def _get_free_port(self):
462463 port = sock .getsockname ()[1 ]
463464 return ip , port
464465
465- def _loop_forever (self ):
466+ async def _loop_forever (self ):
466467 while True :
467- message = self .socket .recv ()
468+ message = await self .socket .recv ()
468469 method , args , kwargs = pickle .loads (message )
469- result = self .execute_method (method , * args , ** kwargs )
470- self .socket .send (pickle .dumps (result ))
470+ result = await self ._execute_method (method , * args , ** kwargs )
471+ await self .socket .send (pickle .dumps (result ))
471472
472- def get_zeromq_address (self ):
473- return self .address
474-
475- def init_worker (self , all_kwargs : list [dict [str , Any ]]):
473+ def _init_worker (self , all_kwargs : list [dict [str , Any ]]):
476474 """Initialize worker engine."""
477475 all_kwargs [0 ]["rank" ] = int (os .environ ["RANK" ])
478476 all_kwargs [0 ]["local_rank" ] = 0 if not ray_noset_visible_devices () else int (os .environ .get ("RAY_LOCAL_RANK" , 0 ))
479477 self .vllm_config = all_kwargs [0 ]["vllm_config" ]
480478 self .inference_engine = WorkerWrapperBase (vllm_config = self .vllm_config )
481479 self .inference_engine .init_worker (all_kwargs )
482480
483- def load_model (self , * args , ** kwargs ):
481+ def _load_model (self , * args , ** kwargs ):
484482 self .inference_engine .load_model (* args , ** kwargs )
485483
486484 # inference engine is initialized now, update sharding manager
@@ -489,28 +487,41 @@ def load_model(self, *args, **kwargs):
489487
490488 _monkey_patch_compute_logits (self .inference_engine .worker .model_runner .model , len (self .tokenizer ))
491489
492- def sleep (self , * args , ** kwargs ):
490+ async def _execute_method (self , method : str | bytes , * args , ** kwargs ):
491+ if method == "init_worker" :
492+ return self ._init_worker (* args , ** kwargs )
493+ elif method == "load_model" :
494+ return self ._load_model (* args , ** kwargs )
495+ elif method == "sleep" :
496+ return await self .sleep (* args , ** kwargs )
497+ elif method == "wake_up" :
498+ return await self .wake_up (* args , ** kwargs )
499+ else :
500+ return self .inference_engine .execute_method (method , * args , ** kwargs )
501+
502+ # ==================== server mode public methods ====================
503+
504+ def get_zeromq_address (self ):
505+ return self .address
506+
507+ async def sleep (self , * args , ** kwargs ):
493508 """Offload model weights and discard kv cache."""
494509 if self .is_sleep :
495510 return
496511 self .sharding_manager .__exit__ (None , None , None )
497512 self .is_sleep = True
498513
499- def wake_up (self , * args , ** kwargs ):
514+ async def wake_up (self , * args , ** kwargs ):
500515 """Load model weights and build kv cache."""
501516 if not self .is_sleep :
502517 return
503518 self .sharding_manager .__enter__ () # pylint: disable=C2801
504519 self .is_sleep = False
505520
506- def execute_method (self , method : str | bytes , * args , ** kwargs ):
507- if method == "init_worker" :
508- return self .init_worker (* args , ** kwargs )
509- elif method == "load_model" :
510- return self .load_model (* args , ** kwargs )
511- elif method == "sleep" :
512- return self .sleep (* args , ** kwargs )
513- elif method == "wake_up" :
514- return self .wake_up (* args , ** kwargs )
515- else :
516- return self .inference_engine .execute_method (method , * args , ** kwargs )
521+ async def generate (self , * args , ** kwargs ):
522+ """Generate sequence with token-in-token-out."""
523+ raise NotImplementedError
524+
525+ async def chat_completion (self , json_request ):
526+ """OpenAI chat completion API."""
527+ raise NotImplementedError
0 commit comments