Skip to content
9 changes: 8 additions & 1 deletion vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,16 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:

# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(port)
await async_engine_client.setup()

try:
while True:
try:
await async_engine_client.setup()
break
except TimeoutError as e:
if not rpc_server_process.is_alive():
raise RuntimeError("Server crashed") from e

yield async_engine_client
finally:
# Ensure rpc server process was terminated
Expand Down
19 changes: 15 additions & 4 deletions vllm/entrypoints/openai/rpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

# Time to wait before checking it the server process is alive.
SERVER_START_TIMEOUT = 1000


class AsyncEngineRPCClient:

Expand Down Expand Up @@ -61,7 +64,9 @@ def socket(self):
socket.connect(self.path)
yield socket
finally:
socket.close()
# linger == 0 means discard unsent messages
# when the socket is closed
socket.close(linger=0)

async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
expected_type: Any,
Expand All @@ -85,14 +90,19 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

return data

async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
error_message: str):
async def _send_one_way_rpc_request(self,
request: RPC_REQUEST_TYPE,
error_message: str,
timeout: Optional[int] = None):
"""Send one-way RPC request to trigger an action."""
with self.socket() as socket:
# Ping RPC Server with request.
await socket.send(cloudpickle.dumps(request))

# Await acknowledgement from RPCServer.
if timeout is not None and await socket.poll(timeout=timeout) == 0:
raise TimeoutError(f"server didn't reply within {timeout} ms")

response = cloudpickle.loads(await socket.recv())

if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
Expand All @@ -117,7 +127,8 @@ async def wait_for_server(self):

await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server.")
error_message="Unable to start RPC Server.",
timeout=SERVER_START_TIMEOUT)

async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
Expand Down