diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 2b1759f714..2dbc71c966 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -33,6 +33,7 @@ from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin from vllm_omni.model_executor.models.output_templates import OmniOutput from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights, safe_tensor_reshape +from vllm_omni.utils.platform_utils import is_npu # Special token IDs for Qwen3 Omni MoE # Reference: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json @@ -326,11 +327,27 @@ def forward( "capture_layer_indices": [0, int(accept_layer)], "return_hidden_states": True, } + if is_npu(): + # TODO: remove this hack when NPU supports batched inputs properly + thinker_input_ids = input_ids[0] if input_ids is not None and _added_batch_dim else input_ids + thinker_positions = positions[0] if positions.ndim > 1 else positions + thinker_inputs_embeds = ( + inputs_embeds[0] if inputs_embeds is not None and _added_batch_dim else inputs_embeds + ) + else: + thinker_input_ids = input_ids + thinker_positions = positions[0] if positions.ndim > 1 else positions + thinker_inputs_embeds = inputs_embeds + # thinker_input_ids = input_ids + # thinker_positions = positions[0] if positions.ndim > 1 else positions + # thinker_inputs_embeds = inputs_embeds + + # Run thinker text_hidden_states, captured_layer_dict = self.thinker( - input_ids=input_ids, - positions=positions[0] if positions.ndim > 1 else positions, + input_ids=thinker_input_ids, + positions=thinker_positions, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, + inputs_embeds=thinker_inputs_embeds, **capture_kwargs, **kwargs, ) diff --git a/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml 
b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml new file mode 100644 index 0000000000..5ebdddd180 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml @@ -0,0 +1,95 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) + +# The following config has been verified on 5x A2/A3-64G NPUs. +stage_args: + - stage_id: 0 + runtime: + devices: "0,1,2,3" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 4 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: true + repetition_penalty: 1.05 + + - stage_id: 1 + runtime: + devices: "4" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.2 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + # tensor_parallel_size: 2 + enable_prefix_caching: false + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [0] + custom_process_input_func: 
vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_cls: vllm_omni.worker.npu.npu_generation_worker.NPUGenerationWorker + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 4096 + hf_config_name: thinker_config + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.1 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index cfa2d37ad7..6887fd969f 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -7,18 +7,20 @@ import torch from vllm.inputs import TextPrompt +from vllm.platforms import current_platform from vllm_omni.inputs.data import OmniTokensPrompt -def _compute_talker_prompt_ids_length(info): +def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int: im_start_token_id = 151644 system_token_id = 8948 user_token_id = 872 assistant_token_id = 77091 - thinker_sequences = torch.tensor(info["thinker_sequences"]).unsqueeze(0).cuda().long() # [1, 
T] - input_ids = torch.tensor(info["thinker_input_ids"]).unsqueeze(0).cuda().long() # [1, T] + thinker_sequences = torch.tensor(info["thinker_sequences"], dtype=torch.long, device=device).unsqueeze(0) # [1, T] + + input_ids = torch.tensor(info["thinker_input_ids"], dtype=torch.long, device=device).unsqueeze(0) # [1, T] im_start_indexes = torch.cat( [ @@ -82,12 +84,14 @@ def thinker2talker( thinker_outputs = stage_list[source_stage_id].engine_outputs talker_inputs = [] + device = torch.device(current_platform.device_type) + # Process each thinker output for i, thinker_output in enumerate(thinker_outputs): output = thinker_output.outputs[0] - thinker_embeddings = output.multimodal_output["0"].float().clone().detach().cuda() + thinker_embeddings = output.multimodal_output["0"].detach().to(device=device, dtype=torch.float) - thinker_hidden_states = output.multimodal_output["24"].float().clone().detach().cuda() + thinker_hidden_states = output.multimodal_output["24"].detach().to(device=device, dtype=torch.float) info = { "thinker_embeddings": thinker_embeddings, "thinker_hidden_states": thinker_hidden_states, @@ -95,13 +99,19 @@ def thinker2talker( + output.token_ids, # the thinker_sequences is the whole ids "thinker_input_ids": thinker_output.prompt_token_ids, # Provide thinker-side TTS token embeddings for talker projection - "tts_bos_embed": output.multimodal_output.get("tts_bos_embed").float().clone().detach().cuda(), - "tts_eos_embed": output.multimodal_output.get("tts_eos_embed").float().clone().detach().cuda(), - "tts_pad_embed": output.multimodal_output.get("tts_pad_embed").float().clone().detach().cuda(), + "tts_bos_embed": ( + output.multimodal_output.get("tts_bos_embed").detach().to(device=device, dtype=torch.float) + ), + "tts_eos_embed": ( + output.multimodal_output.get("tts_eos_embed").detach().to(device=device, dtype=torch.float) + ), + "tts_pad_embed": ( + output.multimodal_output.get("tts_pad_embed").detach().to(device=device, dtype=torch.float) + ), } 
talker_inputs.append( OmniTokensPrompt( - prompt_token_ids=[0] * _compute_talker_prompt_ids_length(info), + prompt_token_ids=[0] * _compute_talker_prompt_ids_length(info, device=device), additional_information=info, multi_modal_data=None, mm_processor_kwargs=None,