Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions vllm_omni/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
_postprocess_messages,
_ToolParser,
)
from vllm.logger import init_logger

logger = init_logger(__name__)


class OmniAsyncMultiModalItemTracker(AsyncMultiModalItemTracker):
Expand Down Expand Up @@ -135,6 +138,48 @@ def _cleanup_file_sync(file_path: str) -> None:
await asyncio.to_thread(_cleanup_file_sync, temp_video_file_path)


def _ensure_system_prompt(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
) -> list[ChatCompletionMessageParam]:
"""
Ensure a system prompt exists for Qwen-Omni models to preserve precision of the model.
Args:
messages: List of chat messages
model_config: Model configuration

Returns:
Messages list with system prompt
"""
model_name = getattr(model_config, "model", "").lower()
hf_config = getattr(model_config, "hf_config", None)
architectures = getattr(hf_config, "architectures", []) if hf_config else []

is_qwen_omni = ("qwen" in model_name and "omni" in model_name) or any(
"qwen" in arch.lower() and "omni" in arch.lower() for arch in architectures
)

if not is_qwen_omni:
return messages

if messages and messages[0].get("role") == "system":
return messages

default_qwen_omni_system_prompt = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)

system_message: ChatCompletionMessageParam = {
"role": "system",
"content": default_qwen_omni_system_prompt,
}

logger.info(f"injecting system prompt {default_qwen_omni_system_prompt} for Qwen-Omni model")
return [system_message] + list(messages)


def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
Expand All @@ -153,6 +198,9 @@ def parse_chat_messages_futures(
Tuple of (conversation, mm_future) where mm_future resolves to
(mm_data, mm_uuids) when awaited.
"""
# auto-inject system prompt for Qwen-Omni models if missing
messages = _ensure_system_prompt(messages, model_config)

conversation: list[ConversationMessage] = []
mm_tracker = OmniAsyncMultiModalItemTracker(model_config)

Expand Down
11 changes: 7 additions & 4 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor,
input_embeds = self.talker.embed_input_ids(input_ids)

span_len = input_ids.shape[0]
update_dict = {}
if span_len > 1:
# prefill
input_ids, input_embeds, update_dict = self.talker_preprocess_prefill(input_ids, input_embeds, **info_dict)
Expand All @@ -600,11 +601,12 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor,
update_dict["code_predictor_codes"] = code_predictor_codes
else:
# decode
if info_dict.get("num_processed_tokens", 0) < len(info_dict.get("thinker_input_ids", [])):
if not info_dict.get("decode_flag", False):
info_dict["num_processed_tokens"] = len(info_dict.get("thinker_input_ids", [])) + 1
update_dict["decode_flag"] = True

last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode(
input_ids, input_embeds, **info_dict
input_ids, input_embeds, update_dict, **info_dict
)
update_dict["mtp_inputs"] = last_talker_hidden, text_step

Expand Down Expand Up @@ -874,8 +876,9 @@ def _thinker_decode_to_talker_decode(
thinker_embed = thinker_embed[start_index : start_index + 1].to(device)
return self.talker.text_projection(thinker_embed).to(device)

def talker_preprocess_decode(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict):
update_dict: dict[str, dict] = {}
def talker_preprocess_decode(
self, input_ids: torch.Tensor, input_embeds: torch.Tensor, update_dict: dict, **info_dict: dict
):
last_talker_hidden = None
text_step = None
try:
Expand Down