2 changes: 1 addition & 1 deletion vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml
@@ -39,7 +39,7 @@ stage_args:
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.8
enforce_eager: false
enforce_eager: true # talker ACL graph is not yet supported on NPU
trust_remote_code: true
enable_prefix_caching: false
engine_output_type: latent
@@ -44,7 +44,7 @@ stage_args:
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.2
enforce_eager: false
enforce_eager: true # talker ACL graph is not yet supported on NPU
trust_remote_code: true
engine_output_type: latent # Output codec codes for code2wav
# tensor_parallel_size: 2
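Both YAML hunks make the same change: the talker stage now forces eager execution because ACL graph capture for the talker is not yet supported on NPU. For readers unfamiliar with the flag, the sketch below shows how `enforce_eager` is typically consumed by a runner; `GraphMode` and `select_graph_mode` are illustrative names, not code from this PR.

```python
# Illustrative only: a minimal sketch of how enforce_eager is typically consumed.
# GraphMode and select_graph_mode are hypothetical names, not part of this PR.
from enum import Enum

class GraphMode(Enum):
    EAGER = 0       # run op-by-op; no graph capture
    FULL_GRAPH = 1  # capture the whole forward as an ACL/CUDA graph

def select_graph_mode(enforce_eager: bool, graph_capture_supported: bool) -> GraphMode:
    # enforce_eager wins unconditionally; otherwise fall back to capability checks.
    if enforce_eager or not graph_capture_supported:
        return GraphMode.EAGER
    return GraphMode.FULL_GRAPH

# With the YAML above (enforce_eager: true) the talker stage always runs eagerly,
# regardless of whether the platform reports graph-capture support.
assert select_graph_mode(True, True) is GraphMode.EAGER
```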
197 changes: 10 additions & 187 deletions vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
@@ -35,7 +35,7 @@
from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.utils import ProfileExecuteDuration

from vllm_omni.core.sched.omni_ar_scheduler import KVCacheTransferData
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
from vllm_omni.outputs import OmniModelRunnerOutput
from vllm_omni.platforms.npu.worker.npu_model_runner import OmniNPUModelRunner

@@ -65,7 +65,8 @@ def __init__(self, *args, **kwargs):
# each model stage has its own hidden size
self.hidden_size = self.model_config.hf_text_config.hidden_size
self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False)
self.omni_connector = None
# Initialize the KV transfer manager (preserves the vllm_config fallback behavior)
self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config)

def _make_buffer(self, *size, dtype, numpy=True):
# Prevent ray from pinning the buffer due to large size
@@ -92,7 +93,13 @@ def execute_model(

# -------------------------------------- Omni-new -------------------------------------------------
# [Omni] Handle KV transfer BEFORE updating states (which removes finished requests)
self.kv_extracted_req_ids = self._handle_finished_requests_kv_transfer(scheduler_output)
self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer(
finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}),
kv_caches=self.kv_caches,
block_size=self.cache_config.block_size,
cache_dtype=str(self.cache_config.cache_dtype),
request_id_resolver=self._resolve_global_request_id,
)
# -------------------------------------- Omni-new -------------------------------------------------

with ProfileExecuteDuration().capture_async("prepare input"):
@@ -499,161 +506,6 @@ def _generate_process_reqs_hidden_states(self, num_input_tokens,
return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states(
hidden_states)

def _handle_finished_requests_kv_transfer(self, scheduler_output: SchedulerOutput) -> list[str]:
"""Handle KV cache transfer for finished requests.

Returns list of request IDs that were processed (for Scheduler to free blocks).
"""
finished_reqs = getattr(scheduler_output, "finished_requests_needing_kv_transfer", {})
if not finished_reqs:
return []

logger.debug(f"Processing KV transfer for {len(finished_reqs)} requests")

extracted_ids = []
for req_id, data in finished_reqs.items():
try:
seq_len = data.get("seq_len", 0)
block_ids = data.get("block_ids", [])
if not block_ids:
logger.warning(f"Request {req_id} has no block IDs, skipping")
continue

# Extract KV cache from GPU blocks -> CPU tensors
kv_data = self._extract_kv_cache(req_id, block_ids, seq_len)
if kv_data:
# Transfer to downstream stage via connector
self._transfer_kv_cache(kv_data)

except Exception as e:
logger.error(f"Failed KV transfer for {req_id}: {e}")
finally:
extracted_ids.append(req_id)

return extracted_ids

def _extract_kv_cache(self, req_id: str, block_ids: list[int], seq_len: int) -> KVCacheTransferData | None:
"""Extract KV cache from GPU blocks for a single request."""
num_layers = len(self.kv_caches)
key_cache = [None] * num_layers
value_cache = [None] * num_layers

for layer_idx, kv_tensor in enumerate(self.kv_caches):
# Validate block IDs
max_block = kv_tensor.shape[1] - 1
valid_ids = [bid for bid in block_ids if 0 <= bid <= max_block]
if not valid_ids:
continue

# Extract and reshape: [2, n_blocks, block_size, n_heads, head_dim]
# -> [2, seq_len, n_heads, head_dim]
selected = kv_tensor[:, valid_ids] # [2, n_valid, block_size, n_heads, head_dim]
n_kv, n_blks, blk_sz, n_heads, d_head = selected.shape
flat = selected.reshape(n_kv, n_blks * blk_sz, n_heads, d_head)
if seq_len < flat.shape[1]:
flat = flat[:, :seq_len]

# Move to CPU
flat_cpu = flat.detach().cpu().contiguous()
key_cache[layer_idx] = flat_cpu[0]
value_cache[layer_idx] = flat_cpu[1]

if not any(k is not None for k in key_cache):
return None

return KVCacheTransferData(
request_id=req_id,
layer_blocks={"key_cache": key_cache, "value_cache": value_cache},
block_ids=block_ids,
metadata={
"block_size": self.cache_config.block_size,
"num_layers": num_layers,
"dtype": str(self.cache_config.cache_dtype),
"seq_len": seq_len,
},
)
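The gather-and-flatten performed by `_extract_kv_cache` above (logic this PR moves into the manager) is easy to check in isolation. The following standalone sketch reproduces it with toy shapes; all sizes and block IDs below are made up for illustration.

```python
# Standalone illustration of the gather-and-flatten step above; toy shapes only.
import torch

n_blocks_total, block_size, n_heads, head_dim = 8, 4, 2, 16
kv_tensor = torch.randn(2, n_blocks_total, block_size, n_heads, head_dim)  # [K/V, blocks, ...]

block_ids, seq_len = [5, 2, 7], 10  # request owns 3 blocks but only 10 valid tokens
selected = kv_tensor[:, block_ids]  # [2, 3, 4, 2, 16]
flat = selected.reshape(2, len(block_ids) * block_size, n_heads, head_dim)  # [2, 12, 2, 16]
flat = flat[:, :seq_len]  # trim block padding -> [2, 10, 2, 16]

# Move to CPU and split into per-request key/value tensors, as the method does per layer.
key_cpu, value_cpu = flat.detach().cpu().contiguous().unbind(0)
assert key_cpu.shape == (seq_len, n_heads, head_dim)
```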

def _transfer_kv_cache(self, kv_data: KVCacheTransferData) -> None:
"""Transfer KV cache data to downstream stage via OmniConnector."""
connector = self._get_or_create_connector()
if not connector:
return

# Resolve global request ID if available
transfer_req_id = self._resolve_global_request_id(kv_data.request_id)
from_stage, to_stage = self._detect_transfer_stages()

# Prepare data and transfer with retry
data_dict = kv_data.to_dict()
data_dict["request_id"] = transfer_req_id

success, size, _ = self._transfer_with_retry(
connector, from_stage, to_stage, f"kv_cache_{transfer_req_id}", data_dict
)

if success:
logger.info(f"KV transfer OK: {transfer_req_id}, {size} bytes")
else:
logger.error(f"KV transfer FAILED: {transfer_req_id}")

def _get_or_create_connector(self) -> Any | None:
"""Get existing connector or create one from config."""
if self.omni_connector:
return self.omni_connector

from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec

config = self._get_omni_connector_config()
if not config or not isinstance(config, dict):
logger.warning("No valid OmniConnector config found")
return None

c_type = config.get("type")
if not c_type:
logger.error("OmniConnector config missing 'type' field")
return None

c_extra = {k: v for k, v in config.items() if k != "type"}
self.omni_connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra))
return self.omni_connector

def _get_omni_connector_config(self) -> dict[str, Any] | None:
"""Get OmniConnector configuration from model config."""
# Primary: omni_kv_config from YAML
omni_kv = getattr(self.model_config, "omni_kv_config", None)
if isinstance(omni_kv, dict):
cfg = omni_kv.get("connector_config")
if isinstance(cfg, dict) and cfg:
return cfg

# Fallback: kv_transfer_config
kv_cfg = getattr(self.vllm_config, "kv_transfer_config", None)
if kv_cfg:
direct = getattr(kv_cfg, "omni_connector_config", None)
if isinstance(direct, dict) and direct:
return direct
extra = getattr(kv_cfg, "kv_connector_extra_config", None)
if isinstance(extra, dict):
omni = extra.get("omni_connector_config")
if isinstance(omni, dict) and omni:
return omni

return None

def _detect_transfer_stages(self) -> tuple[str, str]:
"""Detect source and target stages for KV transfer."""
omni_kv = getattr(self.model_config, "omni_kv_config", None)
if isinstance(omni_kv, dict):
from_s = omni_kv.get("omni_from_stage")
to_s = omni_kv.get("omni_to_stage")
if from_s and to_s:
return str(from_s), str(to_s)

raise ValueError(
"KV transfer stages not configured. Please set 'omni_from_stage' and 'omni_to_stage' in omni_kv_config."
)
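For reference, the two accessors above imply a config shape like the following. The key names (`omni_kv_config`, `connector_config`, `type`, `omni_from_stage`, `omni_to_stage`) come from the removed code; the concrete stage names and connector type below are invented placeholders.

```python
# Hypothetical shape of the config consumed above; key names mirror the removed
# accessors, the concrete values are made up for illustration.
omni_kv_config = {
    "omni_from_stage": "talker",    # source stage of the KV transfer (placeholder)
    "omni_to_stage": "code2wav",    # destination stage (placeholder)
    "connector_config": {
        "type": "shared_memory",    # resolved via OmniConnectorFactory (placeholder)
        # every other key is passed through as ConnectorSpec.extra
        "buffer_size_mb": 256,
    },
}
```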

def _resolve_global_request_id(self, req_id: str) -> str:
"""Resolve global request ID from request state."""
req_state = self.requests.get(req_id)
@@ -669,32 +521,3 @@ def _resolve_global_request_id(self, req_id: str) -> str:
return global_id.decode("utf-8")
return str(global_id)
return req_id

def _transfer_with_retry(
self,
connector: Any,
from_stage: str,
to_stage: str,
request_id: str,
data: dict[str, Any],
max_retries: int = 3,
) -> tuple[bool, int, dict[str, Any] | None]:
"""Transfer data with retry and exponential backoff."""
import time

for attempt in range(max_retries):
try:
put_key = f"omni_{from_stage}_to_{to_stage}_{request_id}"
success, size, metadata = connector.put(
from_stage=from_stage, to_stage=to_stage, put_key=put_key, data=data
)
if success:
return success, size, metadata
logger.warning(f"Transfer attempt {attempt + 1} failed for {request_id}")
except Exception as e:
logger.warning(f"Transfer attempt {attempt + 1} exception: {e}")

if attempt < max_retries - 1:
time.sleep(0.1 * (2**attempt))

return False, 0, None
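Taken together, this file's diff moves roughly 190 lines of KV-transfer plumbing out of the model runner and into `OmniKVTransferManager`. The manager's real implementation lives in `vllm_omni/distributed/omni_connectors/kv_transfer_manager.py` and is not shown in this diff; the skeleton below is only a plausible reconstruction of its interface, inferred from the two call sites in the runner.

```python
# Plausible skeleton of the new manager, inferred from the call sites in this diff
# (from_vllm_config(...) in __init__ and handle_finished_requests_kv_transfer(...)
# in execute_model). Not the actual implementation.
from typing import Any, Callable

class OmniKVTransferManager:
    def __init__(self, omni_kv_config: dict[str, Any] | None, kv_transfer_config: Any):
        self.omni_kv_config = omni_kv_config
        self.kv_transfer_config = kv_transfer_config
        self._connector = None  # created lazily, as the old _get_or_create_connector did

    @classmethod
    def from_vllm_config(cls, vllm_config: Any, model_config: Any) -> "OmniKVTransferManager":
        # Mirrors the removed config lookup: omni_kv_config first,
        # kv_transfer_config as the fallback.
        return cls(
            omni_kv_config=getattr(model_config, "omni_kv_config", None),
            kv_transfer_config=getattr(vllm_config, "kv_transfer_config", None),
        )

    def handle_finished_requests_kv_transfer(
        self,
        finished_reqs: dict[str, dict[str, Any]],
        kv_caches: list[Any],
        block_size: int,
        cache_dtype: str,
        request_id_resolver: Callable[[str], str],
    ) -> list[str]:
        # Same contract as the removed runner method: extract each finished
        # request's KV blocks, push them downstream, and return the request IDs
        # so the scheduler can free their blocks.
        extracted: list[str] = []
        for req_id, data in finished_reqs.items():
            ...  # extract -> transfer with retry, as in the removed helpers
            extracted.append(req_id)
        return extracted
```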
21 changes: 13 additions & 8 deletions vllm_omni/platforms/npu/worker/npu_model_runner.py
@@ -54,12 +54,13 @@ def load_model(self, *args, **kwargs) -> None:
self.model.talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
hidden_size = self.model_config.hf_config.talker_config.text_config.hidden_size
self.talker_mtp_input_ids = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size)
self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32)
self.talker_mtp_inputs_embeds = self._make_buffer(
self.max_num_reqs, hidden_size, dtype=self.dtype, numpy=False
max_batch_size, hidden_size, dtype=self.dtype, numpy=False
)
self.last_talker_hidden = self._make_buffer(self.max_num_reqs, hidden_size, dtype=self.dtype, numpy=False)
self.text_step = self._make_buffer(self.max_num_reqs, hidden_size, dtype=self.dtype, numpy=False)
self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)

def _init_mrope_positions(self, req_state: CachedRequestState):
image_grid_thw = []
@@ -590,12 +591,16 @@ def dummy_drafter_compute_logits(hidden_states):
model_instance=self.model,
):
if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"):
num_tokens_padded_talker_mtp = num_tokens_padded
if num_tokens_padded_talker_mtp == self.max_num_tokens:
num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0]
hidden_states = self.talker_mtp(
self.talker_mtp_input_ids.gpu[:num_tokens_padded],
self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded],
self.last_talker_hidden.gpu[:num_tokens_padded],
self.text_step.gpu[:num_tokens_padded],
self.talker_mtp_input_ids.gpu[:num_tokens_padded_talker_mtp],
self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded_talker_mtp],
self.last_talker_hidden.gpu[:num_tokens_padded_talker_mtp],
self.text_step.gpu[:num_tokens_padded_talker_mtp],
)
self.compilation_config.cache_dir = None
hidden_states = self._generate_dummy_run_hidden_states(
input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds
)
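This file's hunks fix a potential out-of-bounds slice during FULL-graph dummy runs: the talker-MTP buffers are now sized for `max(max_num_reqs, max_cudagraph_capture_size)`, and the dummy-run token count is clamped back to the buffer size when it would otherwise equal `max_num_tokens`. A toy walk-through, with invented numbers but attribute names taken from the diff:

```python
# Toy numbers illustrating the sizing/clamping in the hunks above; the names
# mirror the diff, the values are made up.
max_num_reqs = 256
max_cudagraph_capture_size = 512
max_num_tokens = 8192

# Buffers are now sized for the larger of the two, so FULL-graph capture at
# max_cudagraph_capture_size cannot overrun them.
buffer_rows = max(max_num_reqs, max_cudagraph_capture_size)  # 512

# During the dummy run, a pass padded to the whole token budget is clamped
# back to the buffer size instead of indexing past dim 0 of the buffers.
num_tokens_padded = max_num_tokens
num_tokens_padded_talker_mtp = num_tokens_padded
if num_tokens_padded_talker_mtp == max_num_tokens:
    num_tokens_padded_talker_mtp = buffer_rows  # 512, matches the buffers' dim 0

assert num_tokens_padded_talker_mtp <= buffer_rows
```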