Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,8 @@ def wait_for_kv_layer_from_connector(layer_name: str):
return

connector = get_kv_transfer_group()
if not connector.has_connector_metadata():
return

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
Expand All @@ -865,6 +867,8 @@ def maybe_save_kv_layer_to_connector(
return

connector = get_kv_transfer_group()
if not connector.has_connector_metadata():
return

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
Expand Down
9 changes: 8 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,18 @@ def _get_connector_metadata(self) -> KVConnectorMetadata:
Returns:
ConnectorMetadata: the connector metadata.
"""

# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata

def has_connector_metadata(self) -> bool:
"""Check whether the connector metadata is currently set.

Returns:
bool: True if connector metadata exists, False otherwise.
"""
return self._connector_metadata is not None

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Initialize with the KV caches. Useful for pre-registering the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# We must override the base class method here because we need to bind
# the metadata to each connector in the order of the connectors in the
# MultiKVConnectorMetadata.
#
# Note: Call the base class method to ensure metadata is also set on the
# MultiConnector instance itself; otherwise, `has_connector_metadata()` will
# always return False.
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
if connector_metadata.extra_async_saves:
self._extra_async_saves.update(connector_metadata.extra_async_saves)
for c, cm in zip(self._connectors, connector_metadata.metadata):
c.bind_connector_metadata(cm)
super().bind_connector_metadata(connector_metadata)

def clear_connector_metadata(self) -> None:
for c in self._connectors:
c.clear_connector_metadata()
super().clear_connector_metadata()

def shutdown(self):
exception: Exception | None = None
Expand Down