99import torch
1010import torch .distributed as dist
1111from torch .distributed import ProcessGroup
12- from zmq import PUB , REP , REQ , SUB , SUBSCRIBE , Context # type: ignore
12+ from zmq import SUB , SUBSCRIBE , XPUB , XPUB_VERBOSE , Context # type: ignore
1313
1414import vllm .envs as envs
1515from vllm .logger import init_logger
@@ -153,9 +153,7 @@ class Handle:
153153
154154 buffer : Optional [ShmRingBuffer ] = None
155155 local_subscribe_port : Optional [int ] = None
156- local_sync_port : Optional [int ] = None
157156 remote_subscribe_port : Optional [int ] = None
158- remote_sync_port : Optional [int ] = None
159157
160158
161159class MessageQueue :
@@ -189,38 +187,36 @@ def __init__(
189187 self .buffer = ShmRingBuffer (n_local_reader , max_chunk_bytes ,
190188 max_chunks )
191189
192- self .local_socket = context .socket (PUB )
190+ # XPUB is very similar to PUB, except that it also receives
191+ # subscription messages, which the writer uses to confirm
192+ # the number of subscribers.
193+ self .local_socket = context .socket (XPUB )
194+ # Set the verbose option so that we receive every subscription
195+ # message; otherwise, only the first subscription is delivered.
196+ # See http://api.zeromq.org/3-3:zmq-setsockopt for more details.
197+ self .local_socket .setsockopt (XPUB_VERBOSE , True )
193198 local_subscribe_port = get_open_port ()
194199 self .local_socket .bind (f"tcp://*:{ local_subscribe_port } " )
195200
196- self .local_sync_socket = context .socket (REP )
197- local_sync_port = get_open_port ()
198- self .local_sync_socket .bind (f"tcp://*:{ local_sync_port } " )
199201 self .current_idx = 0
200202
201203 else :
202204 self .buffer = None # type: ignore
203205 local_subscribe_port = None
204- local_sync_port = None
205206 self .local_socket = None
206- self .local_sync_socket = None
207207 self .current_idx = - 1
208208
209209 if n_remote_reader > 0 :
210210 # for remote readers, we will:
211211 # create a publish-subscribe socket to communicate large data
212- self .remote_socket = context .socket (PUB )
212+ self .remote_socket = context .socket (XPUB )
213+ self .remote_socket .setsockopt (XPUB_VERBOSE , True )
213214 remote_subscribe_port = get_open_port ()
214215 self .remote_socket .bind (f"tcp://*:{ remote_subscribe_port } " )
215216
216- self .remote_sync_socket = context .socket (REP )
217- remote_sync_port = get_open_port ()
218- self .remote_sync_socket .bind (f"tcp://*:{ remote_sync_port } " )
219217 else :
220218 remote_subscribe_port = None
221- remote_sync_port = None
222219 self .remote_socket = None
223- self .remote_sync_socket = None
224220
225221 self ._is_writer = True
226222 self ._is_local_reader = False
@@ -233,9 +229,7 @@ def __init__(
233229 local_reader_ranks = local_reader_ranks ,
234230 buffer = self .buffer ,
235231 local_subscribe_port = local_subscribe_port ,
236- local_sync_port = local_sync_port ,
237232 remote_subscribe_port = remote_subscribe_port ,
238- remote_sync_port = remote_sync_port ,
239233 )
240234
241235 logger .info ("vLLM message queue communication handle: %s" , self .handle )
@@ -264,12 +258,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
264258 self .local_socket .connect (
265259 f"tcp://{ handle .connect_ip } :{ handle .local_subscribe_port } " )
266260
267- self .local_sync_socket = context .socket (REQ )
268- self .local_sync_socket .connect (
269- f"tcp://{ handle .connect_ip } :{ handle .local_sync_port } " )
270-
271261 self .remote_socket = None
272- self .remote_sync_socket = None
273262 else :
274263 self .buffer = None # type: ignore
275264 self .current_idx = - 1
@@ -278,17 +267,12 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
278267 self ._is_remote_reader = True
279268
280269 self .local_socket = None
281- self .local_sync_socket = None
282270
283271 self .remote_socket = context .socket (SUB )
284272 self .remote_socket .setsockopt_string (SUBSCRIBE , "" )
285273 self .remote_socket .connect (
286274 f"tcp://{ handle .connect_ip } :{ handle .remote_subscribe_port } " )
287275
288- self .remote_sync_socket = context .socket (REQ )
289- self .remote_sync_socket .connect (
290- f"tcp://{ handle .connect_ip } :{ handle .remote_sync_port } " )
291-
292276 return self
293277
294278 def wait_until_ready (self ):
@@ -300,29 +284,27 @@ def wait_until_ready(self):
300284
301285 # local readers
302286 for i in range (self .n_local_reader ):
303- recv = self .local_sync_socket .recv ()
304- assert recv == b"READY"
305- self .local_sync_socket .send (b"READY" )
287+ # wait for subscription messages from all local readers
288+ self .local_socket .recv ()
306289 if self .n_local_reader > 0 :
290+ # send a message to all local readers
291+ # to make sure the publish channel is working
307292 self .local_socket .send (b"READY" )
308293
309294 # remote readers
310295 for i in range (self .n_remote_reader ):
311- recv = self .remote_sync_socket .recv ()
312- assert recv == b"READY"
313- self .remote_sync_socket .send (b"READY" )
296+ # wait for subscription messages from all remote readers
297+ self .remote_socket .recv ()
314298 if self .n_remote_reader > 0 :
299+ # send a message to all remote readers
300+ # to make sure the publish channel is working
315301 self .remote_socket .send (b"READY" )
316302 elif self ._is_local_reader :
317- self .local_sync_socket .send (b"READY" )
318- recv = self .local_sync_socket .recv ()
319- assert recv == b"READY"
303+ # wait for the writer to send a message
320304 recv = self .local_socket .recv ()
321305 assert recv == b"READY"
322306 elif self ._is_remote_reader :
323- self .remote_sync_socket .send (b"READY" )
324- recv = self .remote_sync_socket .recv ()
325- assert recv == b"READY"
307+ # wait for the writer to send a message
326308 recv = self .remote_socket .recv ()
327309 assert recv == b"READY"
328310
0 commit comments