From d351c9849816f5be6d6891125fc03444ae0a3425 Mon Sep 17 00:00:00 2001 From: princepride Date: Mon, 9 Feb 2026 03:23:28 -0800 Subject: [PATCH 1/6] replace some layers to vllm version Signed-off-by: princepride --- .../models/bagel/bagel_transformer.py | 124 +++++++++++++++--- 1 file changed, 107 insertions(+), 17 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 6ee81a4fd4..95e6901fd8 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -8,6 +8,7 @@ # available at https://github.com/huggingface/transformers/blob/main/LICENSE. import math +from collections.abc import Iterable from dataclasses import dataclass from typing import Any @@ -18,12 +19,16 @@ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, - Qwen2MLP, Qwen2PreTrainedModel, - Qwen2RMSNorm, - Qwen2RotaryEmbedding, ) from transformers.utils import ModelOutput +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.bagel import BagelConfig from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -66,6 +71,70 @@ def forward(self, x): return self.fc2(self.act(self.fc1(x))) +class BagelRotaryEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + dim = config.hidden_size // config.num_attention_heads + inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + """Generate cos/sin embeddings for given position ids. + + Args: + x: Input tensor (only used for dtype inference). + position_ids: Position indices, shape (batch_size, seq_len). + + Returns: + cos, sin: Rotary embeddings, each of shape (batch_size, seq_len, dim). + """ + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class BagelMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "silu", + ) -> None: + super().__init__() + self.gate_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + gather_output=False, + ) + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. Only silu is supported.") + self.act_fn = nn.SiLU() + + def forward(self, x): + gate, _ = self.gate_proj(x) + up, _ = self.up_proj(x) + x = self.act_fn(gate) * up + x, _ = self.down_proj(x) + return x + + torch._dynamo.config.cache_size_limit = 512 torch._dynamo.config.accumulated_cache_size_limit = 4096 flex_attention = torch.compile(flex_attention) @@ -160,10 +229,10 @@ class BaseNavitOutputWithPast(ModelOutput): class PackedAttentionMoT(Qwen2Attention): def __init__(self, config, layer_idx: int | None = None): super().__init__(config, layer_idx) - self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -305,12 +374,12 @@ def __init__( self.self_attn = attn_module(config, layer_idx) - self.mlp = Qwen2MLP(config) - self.mlp_moe_gen = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = BagelMLP(config.hidden_size, config.intermediate_size, config.hidden_act) + self.mlp_moe_gen = BagelMLP(config.hidden_size, config.intermediate_size, config.hidden_act) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -391,7 +460,7 @@ def __init__(self, config): self.vocab_size = config.vocab_size self.use_moe = "Mo" in config.layer_module - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( [ Qwen2MoTDecoderLayer(config, layer_idx, attn_module=PackedAttentionMoT) @@ -399,10 +468,10 @@ def __init__(self, config): ] ) - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.use_moe: - self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.norm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = BagelRotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() @@ -528,6 +597,27 @@ def forward( return outputs + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights for vLLM parallel layers. + + Parameter names (gate_proj, up_proj, down_proj, etc.) are kept + identical to the HF checkpoint, so no name remapping is needed. + vLLM parallel layers attach a ``weight_loader`` to each parameter + that handles TP sharding automatically. + """ + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + param = params_dict.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): grid_h = np.arange(grid_size, dtype=np.float32) From bbf124290ff1385e93896729326bf2698317896c Mon Sep 17 00:00:00 2001 From: princepride Date: Mon, 9 Feb 2026 06:31:39 -0800 Subject: [PATCH 2/6] fix some bug Signed-off-by: princepride --- vllm_omni/diffusion/models/bagel/pipeline_bagel.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index bdb9f1f5c3..b6955390f6 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -402,7 +402,7 @@ def vae_transforms(img): ) # Fail fast with a clear error instead of CUDA gather OOB. max_tid = int(generation_input["packed_text_ids"].max().item()) - emb_n = int(self.language_model.model.embed_tokens.weight.shape[0]) + emb_n = int(self.language_model.vocab_size) if max_tid >= emb_n: raise ValueError( "Tokenizer/model vocab mismatch: max token id " @@ -438,7 +438,7 @@ def vae_transforms(img): ) # Fail fast for special tokens used by the image path as well. max_tid_img = int(generation_input["packed_text_ids"].max().item()) - emb_n = int(self.language_model.model.embed_tokens.weight.shape[0]) + emb_n = int(self.language_model.vocab_size) if max_tid_img >= emb_n: raise ValueError( "Tokenizer/model vocab mismatch (image path): max token id " @@ -483,6 +483,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: allowed = set(state.keys()) shapes = {k: tuple(v.shape) for k, v in state.items()} + tp_aware_params = {name for name, p in self.named_parameters() if hasattr(p, "weight_loader")} + def _normalize_name(name: str) -> str: # Common wrappers/prefixes in checkpoints. for pfx in ("module.", "model."): @@ -536,7 +538,7 @@ def _filtered_weights(): for cand in _iter_candidate_names(name): if cand in allowed: # Only accept if tensor shape matches target param/buffer shape. - if tuple(tensor.shape) == shapes.get(cand): + if tuple(tensor.shape) == shapes.get(cand) or cand in tp_aware_params: picked = cand break else: From ae3cc0f4ab9050f7de7a6494ff469e36a4c4e399 Mon Sep 17 00:00:00 2001 From: princepride Date: Mon, 9 Feb 2026 07:01:26 -0800 Subject: [PATCH 3/6] fix some bug Signed-off-by: princepride --- .../models/bagel/bagel_transformer.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 95e6901fd8..3d454a5143 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -72,10 +72,31 @@ def forward(self, x): class BagelRotaryEmbedding(nn.Module): + """Standalone rotary embedding that generates cos/sin from position ids. + + Replaces HuggingFace's Qwen2RotaryEmbedding while preserving full + ``rope_scaling`` support. When ``config.rope_scaling`` is set (e.g. + linear, dynamic-NTK, YaRN, …), we delegate the ``inv_freq`` / + ``attention_scaling`` computation to HF's ``ROPE_INIT_FUNCTIONS`` so + that the frequency basis and scaling factor are identical to the + original checkpoint. This module has no learnable parameters. + """ + def __init__(self, config): super().__init__() - dim = config.hidden_size // config.num_attention_heads - inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + if config.rope_scaling is not None: + # Delegate to HF's rope-scaling helpers for non-default types. + from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS + + rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + inv_freq, self.attention_scaling = rope_init_fn(config, device=None) + else: + dim = config.hidden_size // config.num_attention_heads + inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.attention_scaling = 1.0 + self.register_buffer("inv_freq", inv_freq, persistent=False) @torch.no_grad() @@ -93,8 +114,8 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) From 15ab18e26c21ad19ec28dc709fb3407a8f706e61 Mon Sep 17 00:00:00 2001 From: princepride Date: Mon, 9 Feb 2026 07:22:39 -0800 Subject: [PATCH 4/6] fix some bug Signed-off-by: princepride --- vllm_omni/diffusion/models/bagel/bagel_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 3d454a5143..efd7391a9c 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -100,7 +100,7 @@ def __init__(self, config): self.register_buffer("inv_freq", inv_freq, persistent=False) @torch.no_grad() - def forward(self, x, position_ids): + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Generate cos/sin embeddings for given position ids. Args: From 47323537ac1cde7858041c63a8726e740845d561 Mon Sep 17 00:00:00 2001 From: princepride Date: Mon, 9 Feb 2026 17:38:39 -0800 Subject: [PATCH 5/6] add docs of bagel tp Signed-off-by: princepride --- examples/offline_inference/bagel/README.md | 18 ++++++++++++++++ examples/online_serving/bagel/README.md | 21 +++++++++++++++++-- .../models/bagel/bagel_transformer.py | 1 + 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/bagel/README.md b/examples/offline_inference/bagel/README.md index 7bcfb42fa1..c1850c27ae 100644 --- a/examples/offline_inference/bagel/README.md +++ b/examples/offline_inference/bagel/README.md @@ -151,6 +151,24 @@ The default yaml configuration deploys Thinker and DiT on the same GPU. You can ------ +#### Tensor Parallelism (TP) + +For larger models or multi-GPU environments, you can enable Tensor Parallelism (TP) by modifying the stage configuration (e.g., [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml)). + +1. **Set `tensor_parallel_size`**: Increase this value (e.g., to `2` or `4`). +2. **Set `devices`**: Specify the comma-separated GPU IDs to be used for the stage (e.g., `"0,1"`). + +Example configuration for TP=2 on GPUs 0 and 1: +```yaml + engine_args: + tensor_parallel_size: 2 + ... + runtime: + devices: "0,1" +``` + +------ + #### 🔗 Runtime Configuration | Parameter | Value | Description | diff --git a/examples/online_serving/bagel/README.md b/examples/online_serving/bagel/README.md index 3fbea0550b..9308acb777 100644 --- a/examples/online_serving/bagel/README.md +++ b/examples/online_serving/bagel/README.md @@ -22,12 +22,29 @@ cd /workspace/vllm-omni/examples/online_serving/bagel bash run_server.sh ``` -If you have a custom stage configs file, launch the server with the command below: - ```bash vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/stage_configs_file ``` +#### 🚀 Tensor Parallelism (TP) + +For larger models or multi-GPU environments, you can enable Tensor Parallelism (TP) for the server. + +1. **Modify Stage Config**: Create or modify a stage configuration yaml (e.g., [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml)). Set `tensor_parallel_size` to `2` (or more) and update `devices` to include multiple GPU IDs (e.g., `"0,1"`). + +```yaml + engine_args: + tensor_parallel_size: 2 + ... + runtime: + devices: "0,1" +``` + +2. **Launch Server**: +```bash +vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/your/custom_bagel.yaml +``` + ### Send Multi-modal Request Get into the bagel folder: diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index efd7391a9c..6d1c1b10e9 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -142,6 +142,7 @@ def __init__( self.down_proj = RowParallelLinear( intermediate_size, hidden_size, + input_is_parallel=True, bias=False, ) if hidden_act != "silu": From 71bb46bb11286537857c68e4051dd9757fdb0e8e Mon Sep 17 00:00:00 2001 From: princepride Date: Tue, 10 Feb 2026 00:54:56 -0800 Subject: [PATCH 6/6] commit add tp for mlp and attention part Signed-off-by: princepride --- .../models/bagel/bagel_transformer.py | 190 ++++++++++++------ .../diffusion/models/bagel/pipeline_bagel.py | 20 ++ 2 files changed, 154 insertions(+), 56 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 6d1c1b10e9..839fc1d727 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -18,13 +18,14 @@ from torch.nn.attention.flex_attention import flex_attention from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2.modeling_qwen2 import ( - Qwen2Attention, Qwen2PreTrainedModel, ) from transformers.utils import ModelOutput +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding @@ -58,17 +59,19 @@ def patchify(imgs, p): class MLPconnector(nn.Module): def __init__(self, input_dim, output_dim, activation="gelu_pytorch_tanh"): super().__init__() - self.fc1 = nn.Linear(input_dim, output_dim) + self.fc1 = ColumnParallelLinear(input_dim, output_dim, bias=True, gather_output=False) if activation == "gelu": self.act = nn.GELU() elif activation == "gelu_pytorch_tanh": self.act = nn.GELU(approximate="tanh") else: self.act = nn.ReLU() - self.fc2 = nn.Linear(output_dim, output_dim) + self.fc2 = RowParallelLinear(output_dim, output_dim, bias=True, input_is_parallel=True) def forward(self, x): - return self.fc2(self.act(self.fc1(x))) + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + return self.fc2(x_parallel)[0] class BagelRotaryEmbedding(nn.Module): @@ -248,25 +251,70 @@ class BaseNavitOutputWithPast(ModelOutput): past_key_values: NaiveCache | None = None -class PackedAttentionMoT(Qwen2Attention): +class PackedAttentionMoT(nn.Module): + """Packed attention with Mixture-of-Tokens routing for understanding/generation. + + Uses vLLM's QKVParallelLinear and RowParallelLinear for tensor parallelism + support, following the same pattern as vLLM's Qwen2Attention. + + The q/k/v projections are stacked into a single QKVParallelLinear: + - qkv_proj : stacks q_proj + k_proj + v_proj (understanding + gen text) + - qkv_proj_moe_gen : stacks q_proj_moe_gen + k_proj_moe_gen + v_proj_moe_gen (gen vae) + """ + def __init__(self, config, layer_idx: int | None = None): - super().__init__(config, layer_idx) + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Understanding mode projections (stacked q/k/v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + ) + + # Generation mode MoE projections (stacked q/k/v) + self.qkv_proj_moe_gen = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + ) + self.o_proj_moe_gen = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + ) + + # QK normalization self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.q_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - - head_dim = self.head_dim - self.q_proj_moe_gen = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True) - self.v_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True) - self.o_proj_moe_gen = nn.Linear(config.num_attention_heads * head_dim, config.hidden_size, bias=False) - self.rotary_op = RotaryEmbedding(is_neox_style=True) def forward( @@ -285,38 +333,43 @@ def forward( packed_text_indexes=None, ): if mode == "und": - packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim) - packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) - packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + qkv, _ = self.qkv_proj(packed_query_sequence) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + packed_query_states = q.view(-1, self.num_heads, self.head_dim) + packed_key_states = k.view(-1, self.num_kv_heads, self.head_dim) + packed_value_states = v.view(-1, self.num_kv_heads, self.head_dim) packed_query_states = self.q_norm(packed_query_states) packed_key_states = self.k_norm(packed_key_states) elif mode == "gen": packed_query_sequence = packed_query_sequence.to(torch.bfloat16) - packed_query_states = packed_query_sequence.new_zeros( - (packed_query_sequence.shape[0], self.num_heads * self.head_dim) - ) - packed_key_states = packed_query_sequence.new_zeros( - (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim) - ) - packed_value_states = packed_query_sequence.new_zeros( - (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim) - ) packed_text_query_sequence = packed_query_sequence[packed_text_indexes] packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] - packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence) - packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence) + # Project text tokens through base qkv + text_qkv, _ = self.qkv_proj(packed_text_query_sequence) + text_q, text_k, text_v = text_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Project vae tokens through moe_gen qkv + vae_qkv, _ = self.qkv_proj_moe_gen(packed_vae_query_sequence) + vae_q, vae_k, vae_v = vae_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence) - packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence) + # Merge into packed tensors + total_len = packed_query_sequence.shape[0] + packed_query_states = packed_query_sequence.new_zeros((total_len, self.q_size)) + packed_key_states = packed_query_sequence.new_zeros((total_len, self.kv_size)) + packed_value_states = packed_query_sequence.new_zeros((total_len, self.kv_size)) - packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence) - packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence) + packed_query_states[packed_text_indexes] = text_q + packed_query_states[packed_vae_token_indexes] = vae_q + packed_key_states[packed_text_indexes] = text_k + packed_key_states[packed_vae_token_indexes] = vae_k + packed_value_states[packed_text_indexes] = text_v + packed_value_states[packed_vae_token_indexes] = vae_v packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim) - packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim) - packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim) + packed_key_states = packed_key_states.view(-1, self.num_kv_heads, self.head_dim) + packed_value_states = packed_value_states.view(-1, self.num_kv_heads, self.head_dim) packed_query_states = packed_query_states.to(torch.float32) packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes]) @@ -343,8 +396,8 @@ def forward( past_value_states = past_key_values.value_cache[self.layer_idx] seqlens = sum(query_lens) + sum(key_values_lens) - merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) - merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) + merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_kv_heads, self.head_dim]) + merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_kv_heads, self.head_dim]) merged_key_states[packed_query_indexes] = packed_key_states merged_key_states[packed_key_value_indexes] = past_key_states merged_value_states[packed_query_indexes] = packed_value_states @@ -368,14 +421,16 @@ def forward( max_seqlen_k=max(key_values_lens).item(), causal=is_causal, ) - packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size) + packed_attn_output = packed_attn_output.reshape(-1, self.q_size) if mode == "und": - packed_attn_output = self.o_proj(packed_attn_output) + packed_attn_output, _ = self.o_proj(packed_attn_output) elif mode == "gen": - packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes]) - packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen( - packed_attn_output[packed_vae_token_indexes] - ) + text_out, _ = self.o_proj(packed_attn_output[packed_text_indexes]) + vae_out, _ = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes]) + full_output = text_out.new_zeros((packed_attn_output.shape[0], self.hidden_size)) + full_output[packed_text_indexes] = text_out + full_output[packed_vae_token_indexes] = vae_out + packed_attn_output = full_output if update_past_key_values: past_key_values.key_cache[self.layer_idx] = merged_key_states @@ -389,7 +444,7 @@ def __init__( self, config, layer_idx: int | None = None, - attn_module: Qwen2Attention | None = PackedAttentionMoT, + attn_module: type[nn.Module] | None = PackedAttentionMoT, ): super().__init__() self.hidden_size = config.hidden_size @@ -622,20 +677,43 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights for vLLM parallel layers. - Parameter names (gate_proj, up_proj, down_proj, etc.) are kept - identical to the HF checkpoint, so no name remapping is needed. - vLLM parallel layers attach a ``weight_loader`` to each parameter - that handles TP sharding automatically. + Handles stacked parameter remapping for QKVParallelLinear: + - q_proj, k_proj, v_proj -> qkv_proj (shard ids: q, k, v) + - q_proj_moe_gen, k_proj_moe_gen, v_proj_moe_gen -> qkv_proj_moe_gen + Other parallel layers (gate_proj, up_proj, down_proj, embed_tokens, etc.) + keep HF checkpoint names and use weight_loader for TP sharding. """ + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # More specific _moe_gen patterns FIRST to avoid substring + # ambiguity (`.q_proj` is a substring of `.q_proj_moe_gen`). + (".qkv_proj_moe_gen", ".q_proj_moe_gen", "q"), + (".qkv_proj_moe_gen", ".k_proj_moe_gen", "k"), + (".qkv_proj_moe_gen", ".v_proj_moe_gen", "v"), + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - param = params_dict.get(name) - if param is None: - continue - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict.get(name) + if param is None: + break + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -751,7 +829,7 @@ def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_pat return pos_ids -class Bagel(torch.nn.Module): +class Bagel(nn.Module): config_class = BagelConfig base_model_prefix = "bagel" diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index b6955390f6..040f7ddecc 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -485,6 +485,26 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: tp_aware_params = {name for name, p in self.named_parameters() if hasattr(p, "weight_loader")} + # Expand allowed/tp_aware_params with stacked param source names. + # QKVParallelLinear merges q_proj+k_proj+v_proj into qkv_proj; the + # checkpoint stores the original separate names. We must recognise + # those names so _filtered_weights does not drop them. + _stacked_expansions = [ + (".qkv_proj", ".q_proj"), + (".qkv_proj", ".k_proj"), + (".qkv_proj", ".v_proj"), + (".qkv_proj_moe_gen", ".q_proj_moe_gen"), + (".qkv_proj_moe_gen", ".k_proj_moe_gen"), + (".qkv_proj_moe_gen", ".v_proj_moe_gen"), + ] + stacked_source_names: set[str] = set() + for name in list(allowed): + for target_suffix, source_suffix in _stacked_expansions: + if target_suffix in name: + stacked_source_names.add(name.replace(target_suffix, source_suffix)) + allowed.update(stacked_source_names) + tp_aware_params.update(stacked_source_names) + def _normalize_name(name: str) -> str: # Common wrappers/prefixes in checkpoints. for pfx in ("module.", "model."):