Skip to content

Commit cd97007

Browse files
youkaichao authored and Alvant committed
[core][distributed] fix zmq hang (vllm-project#6759)
Signed-off-by: Alvant <[email protected]>
1 parent b3502a4 commit cd97007

File tree

2 files changed: 23 additions and 41 deletions

2 files changed

+23
-41
lines changed

vllm/connections.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Mapping, Optional
2+
from typing import Mapping, MutableMapping, Optional
33
from urllib.parse import urlparse
44

55
import aiohttp
@@ -40,7 +40,7 @@ def _validate_http_url(self, url: str):
4040
raise ValueError("Invalid HTTP URL: A valid HTTP URL "
4141
"must have scheme 'http' or 'https'.")
4242

43-
def _headers(self, **extras: str) -> Mapping[str, str]:
43+
def _headers(self, **extras: str) -> MutableMapping[str, str]:
4444
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
4545

4646
def get_response(

vllm/distributed/device_communicators/shm_broadcast.py

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.distributed as dist
1111
from torch.distributed import ProcessGroup
12-
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore
12+
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
1313

1414
import vllm.envs as envs
1515
from vllm.logger import init_logger
@@ -153,9 +153,7 @@ class Handle:
153153

154154
buffer: Optional[ShmRingBuffer] = None
155155
local_subscribe_port: Optional[int] = None
156-
local_sync_port: Optional[int] = None
157156
remote_subscribe_port: Optional[int] = None
158-
remote_sync_port: Optional[int] = None
159157

160158

161159
class MessageQueue:
@@ -189,38 +187,36 @@ def __init__(
189187
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
190188
max_chunks)
191189

192-
self.local_socket = context.socket(PUB)
190+
# XPUB is very similar to PUB,
191+
# except that it can receive subscription messages
192+
# to confirm the number of subscribers
193+
self.local_socket = context.socket(XPUB)
194+
# set the verbose option so that we can receive every subscription
195+
# message. otherwise, we will only receive the first subscription
196+
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
197+
self.local_socket.setsockopt(XPUB_VERBOSE, True)
193198
local_subscribe_port = get_open_port()
194199
self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
195200

196-
self.local_sync_socket = context.socket(REP)
197-
local_sync_port = get_open_port()
198-
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
199201
self.current_idx = 0
200202

201203
else:
202204
self.buffer = None # type: ignore
203205
local_subscribe_port = None
204-
local_sync_port = None
205206
self.local_socket = None
206-
self.local_sync_socket = None
207207
self.current_idx = -1
208208

209209
if n_remote_reader > 0:
210210
# for remote readers, we will:
211211
# create a publish-subscribe socket to communicate large data
212-
self.remote_socket = context.socket(PUB)
212+
self.remote_socket = context.socket(XPUB)
213+
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
213214
remote_subscribe_port = get_open_port()
214215
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
215216

216-
self.remote_sync_socket = context.socket(REP)
217-
remote_sync_port = get_open_port()
218-
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
219217
else:
220218
remote_subscribe_port = None
221-
remote_sync_port = None
222219
self.remote_socket = None
223-
self.remote_sync_socket = None
224220

225221
self._is_writer = True
226222
self._is_local_reader = False
@@ -233,9 +229,7 @@ def __init__(
233229
local_reader_ranks=local_reader_ranks,
234230
buffer=self.buffer,
235231
local_subscribe_port=local_subscribe_port,
236-
local_sync_port=local_sync_port,
237232
remote_subscribe_port=remote_subscribe_port,
238-
remote_sync_port=remote_sync_port,
239233
)
240234

241235
logger.info("vLLM message queue communication handle: %s", self.handle)
@@ -264,12 +258,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
264258
self.local_socket.connect(
265259
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
266260

267-
self.local_sync_socket = context.socket(REQ)
268-
self.local_sync_socket.connect(
269-
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
270-
271261
self.remote_socket = None
272-
self.remote_sync_socket = None
273262
else:
274263
self.buffer = None # type: ignore
275264
self.current_idx = -1
@@ -278,17 +267,12 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
278267
self._is_remote_reader = True
279268

280269
self.local_socket = None
281-
self.local_sync_socket = None
282270

283271
self.remote_socket = context.socket(SUB)
284272
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
285273
self.remote_socket.connect(
286274
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
287275

288-
self.remote_sync_socket = context.socket(REQ)
289-
self.remote_sync_socket.connect(
290-
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
291-
292276
return self
293277

294278
def wait_until_ready(self):
@@ -300,29 +284,27 @@ def wait_until_ready(self):
300284

301285
# local readers
302286
for i in range(self.n_local_reader):
303-
recv = self.local_sync_socket.recv()
304-
assert recv == b"READY"
305-
self.local_sync_socket.send(b"READY")
287+
# wait for subscription messages from all local readers
288+
self.local_socket.recv()
306289
if self.n_local_reader > 0:
290+
# send a message to all local readers
291+
# to make sure the publish channel is working
307292
self.local_socket.send(b"READY")
308293

309294
# remote readers
310295
for i in range(self.n_remote_reader):
311-
recv = self.remote_sync_socket.recv()
312-
assert recv == b"READY"
313-
self.remote_sync_socket.send(b"READY")
296+
# wait for subscription messages from all remote readers
297+
self.remote_socket.recv()
314298
if self.n_remote_reader > 0:
299+
# send a message to all remote readers
300+
# to make sure the publish channel is working
315301
self.remote_socket.send(b"READY")
316302
elif self._is_local_reader:
317-
self.local_sync_socket.send(b"READY")
318-
recv = self.local_sync_socket.recv()
319-
assert recv == b"READY"
303+
# wait for the writer to send a message
320304
recv = self.local_socket.recv()
321305
assert recv == b"READY"
322306
elif self._is_remote_reader:
323-
self.remote_sync_socket.send(b"READY")
324-
recv = self.remote_sync_socket.recv()
325-
assert recv == b"READY"
307+
# wait for the writer to send a message
326308
recv = self.remote_socket.recv()
327309
assert recv == b"READY"
328310

0 commit comments

Comments
 (0)