diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 17e025155a43..2bab9f1c97ab 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -848,6 +848,8 @@ def wait_for_kv_layer_from_connector(layer_name: str):
         return
 
     connector = get_kv_transfer_group()
+    if not connector.has_connector_metadata():
+        return
 
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
@@ -865,6 +867,8 @@ def maybe_save_kv_layer_to_connector(
         return
 
     connector = get_kv_transfer_group()
+    if not connector.has_connector_metadata():
+        return
 
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 354aa9a87183..f85eb414b222 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -204,11 +204,18 @@ def _get_connector_metadata(self) -> KVConnectorMetadata:
         Returns:
             ConnectorMetadata: the connector metadata.
         """
-
         # Should only be called while set to valid metadata.
         assert self._connector_metadata is not None
         return self._connector_metadata
 
+    def has_connector_metadata(self) -> bool:
+        """Check whether the connector metadata is currently set.
+
+        Returns:
+            bool: True if connector metadata exists, False otherwise.
+        """
+        return self._connector_metadata is not None
+
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """
         Initialize with the KV caches. Useful for pre-registering the
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index d7bbf02c8367..c9d08e9b78ed 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -171,16 +171,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
     # We must override the base class method here because we need to bind
     # the metadata to each connector in the order of the connectors in the
     # MultiKVConnectorMetadata.
+    #
+    # Note: Call the base class method to ensure metadata is also set on the
+    # MultiConnector instance itself; otherwise, `has_connector_metadata()` will
+    # always return False.
     def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
         assert isinstance(connector_metadata, MultiKVConnectorMetadata)
         if connector_metadata.extra_async_saves:
             self._extra_async_saves.update(connector_metadata.extra_async_saves)
         for c, cm in zip(self._connectors, connector_metadata.metadata):
             c.bind_connector_metadata(cm)
+        super().bind_connector_metadata(connector_metadata)
 
     def clear_connector_metadata(self) -> None:
         for c in self._connectors:
             c.clear_connector_metadata()
+        super().clear_connector_metadata()
 
     def shutdown(self):
         exception: Exception | None = None