Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
from vllm_omni.model_executor.models.output_templates import OmniOutput
from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights, safe_tensor_reshape
from vllm_omni.utils.platform_utils import is_npu

# Special token IDs for Qwen3 Omni MoE
# Reference: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
Expand Down Expand Up @@ -326,11 +327,27 @@ def forward(
"capture_layer_indices": [0, int(accept_layer)],
"return_hidden_states": True,
}
if is_npu():
# TODO: remove this hack when NPU supports batched inputs properly
thinker_input_ids = input_ids[0] if input_ids is not None and _added_batch_dim else input_ids
thinker_positions = positions[0] if positions.ndim > 1 else positions
thinker_inputs_embeds = (
inputs_embeds[0] if inputs_embeds is not None and _added_batch_dim else inputs_embeds
)
else:
thinker_input_ids = input_ids
thinker_positions = positions[0] if positions.ndim > 1 else positions
thinker_inputs_embeds = inputs_embeds
# thinker_input_ids = input_ids
# thinker_positions = positions[0] if positions.ndim > 1 else positions
# thinker_inputs_embeds = inputs_embeds

# Run thinker
text_hidden_states, captured_layer_dict = self.thinker(
input_ids=input_ids,
positions=positions[0] if positions.ndim > 1 else positions,
input_ids=thinker_input_ids,
positions=thinker_positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
inputs_embeds=thinker_inputs_embeds,
**capture_kwargs,
**kwargs,
)
Expand Down
95 changes: 95 additions & 0 deletions vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
# Stage 0: Thinker (multimodal understanding + text generation)
# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)

# The following config has been verified on 5x A2/A3-64G NPUs.
#
# NOTE(review): the indentation of this file was lost in transit; the nesting
# below is reconstructed by convention (per-stage `runtime`/`engine_args`
# maps, with stage-level routing flags and `default_sampling_params` as
# siblings of `engine_args`). Verify against an existing stage config for
# this project before relying on it.
stage_args:
  - stage_id: 0
    runtime:
      devices: "0,1,2,3"  # thinker spans 4 NPUs (matches tensor_parallel_size below)
      max_batch_size: 1
    engine_args:
      model_stage: thinker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.6
      enforce_eager: true
      trust_remote_code: true
      engine_output_type: latent  # Output hidden states for talker
      distributed_executor_backend: "mp"
      enable_prefix_caching: false
      hf_config_name: thinker_config
      tensor_parallel_size: 4
    final_output: true
    final_output_type: text
    is_comprehension: true
    default_sampling_params:
      temperature: 0.4
      top_p: 0.9
      top_k: 1
      max_tokens: 2048
      seed: 42
      detokenize: True
      repetition_penalty: 1.05

  - stage_id: 1
    runtime:
      devices: "4"
      max_batch_size: 1
    engine_args:
      model_stage: talker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.2
      enforce_eager: true
      trust_remote_code: true
      engine_output_type: latent  # Output codec codes for code2wav
      # tensor_parallel_size: 2
      enable_prefix_caching: false
      distributed_executor_backend: "mp"
      hf_config_name: talker_config
    engine_input_source: [0]  # consumes stage-0 (thinker) outputs
    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
    # final_output: true
    # final_output_type: text
    default_sampling_params:
      temperature: 0.9
      top_k: 50
      max_tokens: 4096
      seed: 42
      detokenize: False
      repetition_penalty: 1.05
      stop_token_ids: [2150]  # presumably the talker codec EOS id — confirm against tokenizer config

  - stage_id: 2
    runtime:
      devices: "0"  # shares NPU 0 with stage 0; low gpu_memory_utilization below keeps them co-resident
      max_batch_size: 1
    engine_args:
      model_stage: code2wav
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_cls: vllm_omni.worker.npu.npu_generation_worker.NPUGenerationWorker
      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
      enforce_eager: true
      trust_remote_code: true
      enable_prefix_caching: false
      engine_output_type: audio  # Final output: audio waveform
      gpu_memory_utilization: 0.1
      distributed_executor_backend: "mp"
      max_num_batched_tokens: 4096
      hf_config_name: thinker_config
    engine_input_source: [1]  # consumes stage-1 (talker) outputs
    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
    final_output: true
    final_output_type: audio
    default_sampling_params:
      temperature: 0.0
      top_p: 1.0
      top_k: -1
      max_tokens: 65536
      seed: 42
      detokenize: True
      repetition_penalty: 1.1
28 changes: 19 additions & 9 deletions vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@

import torch
from vllm.inputs import TextPrompt
from vllm.platforms import current_platform

from vllm_omni.inputs.data import OmniTokensPrompt


def _compute_talker_prompt_ids_length(info):
def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int:
im_start_token_id = 151644
system_token_id = 8948
user_token_id = 872
assistant_token_id = 77091

thinker_sequences = torch.tensor(info["thinker_sequences"]).unsqueeze(0).cuda().long() # [1, T]
input_ids = torch.tensor(info["thinker_input_ids"]).unsqueeze(0).cuda().long() # [1, T]
thinker_sequences = torch.tensor(info["thinker_sequences"], dtype=torch.long, device=device).unsqueeze(0) # [1, T]

input_ids = torch.tensor(info["thinker_input_ids"], dtype=torch.long, device=device).unsqueeze(0) # [1, T]

im_start_indexes = torch.cat(
[
Expand Down Expand Up @@ -82,26 +84,34 @@ def thinker2talker(
thinker_outputs = stage_list[source_stage_id].engine_outputs
talker_inputs = []

device = torch.device(current_platform.device_type)

# Process each thinker output
for i, thinker_output in enumerate(thinker_outputs):
output = thinker_output.outputs[0]
thinker_embeddings = output.multimodal_output["0"].float().clone().detach().cuda()
thinker_embeddings = output.multimodal_output["0"].detach().to(device=device, dtype=torch.float)

thinker_hidden_states = output.multimodal_output["24"].float().clone().detach().cuda()
thinker_hidden_states = output.multimodal_output["24"].detach().to(device=device, dtype=torch.float)
info = {
"thinker_embeddings": thinker_embeddings,
"thinker_hidden_states": thinker_hidden_states,
"thinker_sequences": thinker_output.prompt_token_ids
+ output.token_ids, # the thinker_sequences is the whole ids
"thinker_input_ids": thinker_output.prompt_token_ids,
# Provide thinker-side TTS token embeddings for talker projection
"tts_bos_embed": output.multimodal_output.get("tts_bos_embed").float().clone().detach().cuda(),
"tts_eos_embed": output.multimodal_output.get("tts_eos_embed").float().clone().detach().cuda(),
"tts_pad_embed": output.multimodal_output.get("tts_pad_embed").float().clone().detach().cuda(),
"tts_bos_embed": (
output.multimodal_output.get("tts_bos_embed").detach().to(device=device, dtype=torch.float)
),
"tts_eos_embed": (
output.multimodal_output.get("tts_eos_embed").detach().to(device=device, dtype=torch.float)
),
"tts_pad_embed": (
output.multimodal_output.get("tts_pad_embed").detach().to(device=device, dtype=torch.float)
),
}
talker_inputs.append(
OmniTokensPrompt(
prompt_token_ids=[0] * _compute_talker_prompt_ids_length(info),
prompt_token_ids=[0] * _compute_talker_prompt_ids_length(info, device=device),
additional_information=info,
multi_modal_data=None,
mm_processor_kwargs=None,
Expand Down