
[RFC]: Fix for NIXL metadata exchange issue when the worker is multi-node #25981

@GuanLuo

Motivation.

This issue was identified long ago, and several PRs have tried to address it. None of them made it to the finish line, so I would like to start another discussion to summarize the issue and propose a fix. Note that this proposal does not intend to fix the other parallelism issues around NIXL (see #22430); it focuses only on the metadata exchange:

The issue is described in #19080, which was itself an attempt to fix it. In summary, the NIXL handshake is worker-to-worker, and the decode worker tries to reach the prefill worker using the "remote host/port" in the connector metadata. However, the "remote host" is the address of the head node, so for prefill workers on other nodes the correct host address is never passed to the decode worker and the handshake fails. So we want to

  • centralize the handshake metadata from all workers to the head node engine, so that any downstream worker can obtain desired metadata from the single "remote host" pointing to the head node.
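The idea can be illustrated with a minimal sketch. It is not the vLLM or NIXL implementation: `HandshakeMetadata`, `collect_on_head`, and `lookup` are hypothetical names, and the host names are made up. It only shows that once every worker's record is held on the head node, a decode worker can resolve the real address of any prefill rank through that single endpoint.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HandshakeMetadata:
    """Hypothetical per-worker record; field names are illustrative only."""
    rank: int
    host: str              # node the worker actually runs on
    agent_metadata: bytes  # opaque blob the transfer agent would exchange


def collect_on_head(records: list[HandshakeMetadata]) -> dict[int, HandshakeMetadata]:
    """Gather every worker's record on the head node, keyed by rank."""
    table: dict[int, HandshakeMetadata] = {}
    for record in records:
        if record.rank in table:
            raise ValueError(f"duplicate metadata for rank {record.rank}")
        table[record.rank] = record
    return table


def lookup(table: dict[int, HandshakeMetadata], remote_rank: int) -> HandshakeMetadata:
    """What a decode worker would get back when asking the head node about a rank."""
    return table[remote_rank]


# Two prefill workers on different nodes; only "node-a" is the head node.
table = collect_on_head([
    HandshakeMetadata(rank=0, host="node-a", agent_metadata=b"agent-0"),
    HandshakeMetadata(rank=1, host="node-b", agent_metadata=b"agent-1"),
])

# Without centralization the decode worker would assume "node-a" for rank 1.
# With the table it learns the worker's real host.
rank1 = lookup(table, 1)
```

Here `rank1.host` is `"node-b"` rather than the head node address, which is the information the worker-to-worker handshake is currently missing.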

#19080 addressed this, but it hit certain edge cases at the time, which suggested that the centralization should be done as a step separate from NIXL connector initialization. One such edge case is that the rank 0 worker may not be initialized on the head node, so a worker cannot assume whether it should nominate itself as the metadata exchange listener.

To address this, #22274 shows a good approach: introduce an RPC that collects the metadata from all workers during engine initialization, and create the listener that hosts the metadata afterwards. However, #22274 creates the listener outside the vLLM engine, which results in a long propagation chain and an implementation leak (the HTTP server needs to know about NIXL connector internals).

This proposal takes the idea from #22274 and tries to keep the change within EngineCore and NixlConnector (still using zmq for the metadata exchange).

Proposed Change.

Add an RPC get_kv_connector_handshake_metadata() to the executor to collect metadata from all workers, as shown in #22274:

# Collect and store KV connector xfer metadata from workers
# (after KV cache registration)
self.xfer_handshake_metadata = (
    self.model_executor.get_kv_connector_handshake_metadata())

Instead of propagating it all the way up, feed it back into kv_cache_config for scheduler initialization. This way, the metadata is passed to NixlConnectorScheduler, and the listener can be managed by the connector scheduler, as proposed here.
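A rough sketch of this flow, using stand-in classes rather than real vLLM types: `FakeWorker`, `FakeExecutor`, and `FakeConnectorScheduler` are hypothetical, the shape of the metadata (a per-worker dict keyed by rank) is an assumption, and the zmq listener is reduced to a plain method call so that the example runs without a network.

```python
from typing import Any


class FakeWorker:
    """Stand-in for a worker process; not a vLLM class."""

    def __init__(self, rank: int, host: str) -> None:
        self.rank = rank
        self.host = host

    def get_kv_connector_handshake_metadata(self) -> dict[int, dict[str, Any]]:
        # Each worker reports only its own record (assumed shape).
        return {self.rank: {"host": self.host,
                            "agent_metadata": f"agent-{self.rank}".encode()}}


class FakeExecutor:
    """Stand-in for the executor: fans the RPC out to every worker."""

    def __init__(self, workers: list[FakeWorker]) -> None:
        self.workers = workers

    def get_kv_connector_handshake_metadata(self) -> list[dict[int, dict[str, Any]]]:
        return [w.get_kv_connector_handshake_metadata() for w in self.workers]


class FakeConnectorScheduler:
    """Stand-in for the scheduler-side connector that would own the listener."""

    def __init__(self, xfer_handshake_metadata: list[dict[int, dict[str, Any]]]) -> None:
        # Merge the per-worker results into one table held on the head node.
        self.table: dict[int, dict[str, Any]] = {}
        for per_worker in xfer_handshake_metadata:
            self.table.update(per_worker)

    def handle_request(self, rank: int) -> dict[str, Any]:
        # In the proposal a zmq listener would serve this; here it is a direct call.
        return self.table[rank]


# Engine initialization, after KV cache registration (simplified).
executor = FakeExecutor([FakeWorker(0, "node-a"), FakeWorker(1, "node-b")])
xfer_handshake_metadata = executor.get_kv_connector_handshake_metadata()

# The collected metadata is handed to the scheduler-side connector
# instead of being propagated up to the API server.
connector_scheduler = FakeConnectorScheduler(xfer_handshake_metadata)
reply = connector_scheduler.handle_request(1)
```

The point of the design is visible in the last three statements: the executor result goes straight into the scheduler-side connector, so nothing outside the engine needs to know what the metadata contains.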

With this modification, user interaction remains the same and the NIXL handshake behavior is contained within NixlConnector (similar to #19080).

Feedback Period.

No response

CC List.

No response

Any Other Things.

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
