diff --git a/examples/offline_inference/bagel/README.md b/examples/offline_inference/bagel/README.md index 7bcfb42fa1..c1850c27ae 100644 --- a/examples/offline_inference/bagel/README.md +++ b/examples/offline_inference/bagel/README.md @@ -151,6 +151,24 @@ The default yaml configuration deploys Thinker and DiT on the same GPU. You can ------ +#### Tensor Parallelism (TP) + +For larger models or multi-GPU environments, you can enable Tensor Parallelism (TP) by modifying the stage configuration (e.g., [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml)). + +1. **Set `tensor_parallel_size`**: Increase this value (e.g., to `2` or `4`). +2. **Set `devices`**: Specify the comma-separated GPU IDs to be used for the stage (e.g., `"0,1"`). + +Example configuration for TP=2 on GPUs 0 and 1: +```yaml + engine_args: + tensor_parallel_size: 2 + ... + runtime: + devices: "0,1" +``` + +------ + #### 🔗 Runtime Configuration | Parameter | Value | Description | diff --git a/examples/online_serving/bagel/README.md b/examples/online_serving/bagel/README.md index 3fbea0550b..9308acb777 100644 --- a/examples/online_serving/bagel/README.md +++ b/examples/online_serving/bagel/README.md @@ -22,12 +22,29 @@ cd /workspace/vllm-omni/examples/online_serving/bagel bash run_server.sh ``` -If you have a custom stage configs file, launch the server with the command below: - ```bash vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/stage_configs_file ``` +#### 🚀 Tensor Parallelism (TP) + +For larger models or multi-GPU environments, you can enable Tensor Parallelism (TP) for the server. + +1. **Modify Stage Config**: Create or modify a stage configuration yaml (e.g., [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml)). Set `tensor_parallel_size` to `2` (or more) and update `devices` to include multiple GPU IDs (e.g., `"0,1"`). + +```yaml + engine_args: + tensor_parallel_size: 2 + ... + runtime: + devices: "0,1" +``` + +2. **Launch Server**: +```bash +vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/your/custom_bagel.yaml +``` + ### Send Multi-modal Request Get into the bagel folder: diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 6ee81a4fd4..839fc1d727 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -8,6 +8,7 @@ # available at https://github.com/huggingface/transformers/blob/main/LICENSE. 
import math +from collections.abc import Iterable from dataclasses import dataclass from typing import Any @@ -17,13 +18,18 @@ from torch.nn.attention.flex_attention import flex_attention from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2.modeling_qwen2 import ( - Qwen2Attention, - Qwen2MLP, Qwen2PreTrainedModel, - Qwen2RMSNorm, - Qwen2RotaryEmbedding, ) from transformers.utils import ModelOutput +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.bagel import BagelConfig from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -53,17 +59,105 @@ def patchify(imgs, p): class MLPconnector(nn.Module): def __init__(self, input_dim, output_dim, activation="gelu_pytorch_tanh"): super().__init__() - self.fc1 = nn.Linear(input_dim, output_dim) + self.fc1 = ColumnParallelLinear(input_dim, output_dim, bias=True, gather_output=False) if activation == "gelu": self.act = nn.GELU() elif activation == "gelu_pytorch_tanh": self.act = nn.GELU(approximate="tanh") else: self.act = nn.ReLU() - self.fc2 = nn.Linear(output_dim, output_dim) + self.fc2 = RowParallelLinear(output_dim, output_dim, bias=True, input_is_parallel=True) def forward(self, x): - return self.fc2(self.act(self.fc1(x))) + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + return self.fc2(x_parallel)[0] + + +class BagelRotaryEmbedding(nn.Module): + """Standalone rotary embedding that generates cos/sin from position ids. + + Replaces HuggingFace's Qwen2RotaryEmbedding while preserving full + ``rope_scaling`` support. When ``config.rope_scaling`` is set (e.g. + linear, dynamic-NTK, YaRN, …), we delegate the ``inv_freq`` / + ``attention_scaling`` computation to HF's ``ROPE_INIT_FUNCTIONS`` so + that the frequency basis and scaling factor are identical to the + original checkpoint. This module has no learnable parameters. + """ + + def __init__(self, config): + super().__init__() + + if config.rope_scaling is not None: + # Delegate to HF's rope-scaling helpers for non-default types. + from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS + + rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + inv_freq, self.attention_scaling = rope_init_fn(config, device=None) + else: + dim = config.hidden_size // config.num_attention_heads + inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.attention_scaling = 1.0 + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Generate cos/sin embeddings for given position ids. + + Args: + x: Input tensor (only used for dtype inference). + position_ids: Position indices, shape (batch_size, seq_len). + + Returns: + cos, sin: Rotary embeddings, each of shape (batch_size, seq_len, dim). 
+ """ + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class BagelMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "silu", + ) -> None: + super().__init__() + self.gate_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + gather_output=False, + ) + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + input_is_parallel=True, + bias=False, + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. Only silu is supported.") + self.act_fn = nn.SiLU() + + def forward(self, x): + gate, _ = self.gate_proj(x) + up, _ = self.up_proj(x) + x = self.act_fn(gate) * up + x, _ = self.down_proj(x) + return x torch._dynamo.config.cache_size_limit = 512 @@ -157,24 +251,69 @@ class BaseNavitOutputWithPast(ModelOutput): past_key_values: NaiveCache | None = None -class PackedAttentionMoT(Qwen2Attention): - def __init__(self, config, layer_idx: int | None = None): - super().__init__(config, layer_idx) - self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) +class PackedAttentionMoT(nn.Module): + """Packed attention with Mixture-of-Tokens routing for understanding/generation. + + Uses vLLM's QKVParallelLinear and RowParallelLinear for tensor parallelism + support, following the same pattern as vLLM's Qwen2Attention. 
+ + The q/k/v projections are stacked into a single QKVParallelLinear: + - qkv_proj : stacks q_proj + k_proj + v_proj (understanding + gen text) + - qkv_proj_moe_gen : stacks q_proj_moe_gen + k_proj_moe_gen + v_proj_moe_gen (gen vae) + """ + def __init__(self, config, layer_idx: int | None = None): + super().__init__() + self.layer_idx = layer_idx self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - head_dim = self.head_dim - self.q_proj_moe_gen = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True) - self.v_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True) - self.o_proj_moe_gen = nn.Linear(config.num_attention_heads * head_dim, config.hidden_size, bias=False) + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Understanding mode projections (stacked q/k/v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + ) + + # Generation mode MoE projections (stacked q/k/v) + self.qkv_proj_moe_gen = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + ) + self.o_proj_moe_gen = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + ) + + # QK normalization + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.rotary_op = RotaryEmbedding(is_neox_style=True) @@ -194,38 +333,43 @@ def forward( packed_text_indexes=None, ): if mode == "und": - packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim) - packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) - packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + qkv, _ = self.qkv_proj(packed_query_sequence) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + packed_query_states = q.view(-1, self.num_heads, self.head_dim) + packed_key_states = k.view(-1, self.num_kv_heads, self.head_dim) + packed_value_states = v.view(-1, self.num_kv_heads, self.head_dim) packed_query_states = self.q_norm(packed_query_states) packed_key_states = self.k_norm(packed_key_states) elif mode == "gen": packed_query_sequence = packed_query_sequence.to(torch.bfloat16) - packed_query_states = 
packed_query_sequence.new_zeros( - (packed_query_sequence.shape[0], self.num_heads * self.head_dim) - ) - packed_key_states = packed_query_sequence.new_zeros( - (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim) - ) - packed_value_states = packed_query_sequence.new_zeros( - (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim) - ) packed_text_query_sequence = packed_query_sequence[packed_text_indexes] packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] - packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence) - packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence) + # Project text tokens through base qkv + text_qkv, _ = self.qkv_proj(packed_text_query_sequence) + text_q, text_k, text_v = text_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence) - packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence) + # Project vae tokens through moe_gen qkv + vae_qkv, _ = self.qkv_proj_moe_gen(packed_vae_query_sequence) + vae_q, vae_k, vae_v = vae_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence) - packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence) + # Merge into packed tensors + total_len = packed_query_sequence.shape[0] + packed_query_states = packed_query_sequence.new_zeros((total_len, self.q_size)) + packed_key_states = packed_query_sequence.new_zeros((total_len, self.kv_size)) + packed_value_states = packed_query_sequence.new_zeros((total_len, self.kv_size)) + + packed_query_states[packed_text_indexes] = text_q + packed_query_states[packed_vae_token_indexes] = vae_q + packed_key_states[packed_text_indexes] = text_k + packed_key_states[packed_vae_token_indexes] = vae_k + packed_value_states[packed_text_indexes] = text_v + packed_value_states[packed_vae_token_indexes] = vae_v packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim) - packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim) - packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim) + packed_key_states = packed_key_states.view(-1, self.num_kv_heads, self.head_dim) + packed_value_states = packed_value_states.view(-1, self.num_kv_heads, self.head_dim) packed_query_states = packed_query_states.to(torch.float32) packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes]) @@ -252,8 +396,8 @@ def forward( past_value_states = past_key_values.value_cache[self.layer_idx] seqlens = sum(query_lens) + sum(key_values_lens) - merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) - merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) + merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_kv_heads, self.head_dim]) + merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_kv_heads, self.head_dim]) merged_key_states[packed_query_indexes] = packed_key_states merged_key_states[packed_key_value_indexes] = past_key_states merged_value_states[packed_query_indexes] = packed_value_states @@ -277,14 +421,16 @@ def forward( max_seqlen_k=max(key_values_lens).item(), causal=is_causal, ) - packed_attn_output = 
packed_attn_output.reshape(-1, self.hidden_size) + packed_attn_output = packed_attn_output.reshape(-1, self.q_size) if mode == "und": - packed_attn_output = self.o_proj(packed_attn_output) + packed_attn_output, _ = self.o_proj(packed_attn_output) elif mode == "gen": - packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes]) - packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen( - packed_attn_output[packed_vae_token_indexes] - ) + text_out, _ = self.o_proj(packed_attn_output[packed_text_indexes]) + vae_out, _ = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes]) + full_output = text_out.new_zeros((packed_attn_output.shape[0], self.hidden_size)) + full_output[packed_text_indexes] = text_out + full_output[packed_vae_token_indexes] = vae_out + packed_attn_output = full_output if update_past_key_values: past_key_values.key_cache[self.layer_idx] = merged_key_states @@ -298,19 +444,19 @@ def __init__( self, config, layer_idx: int | None = None, - attn_module: Qwen2Attention | None = PackedAttentionMoT, + attn_module: type[nn.Module] | None = PackedAttentionMoT, ): super().__init__() self.hidden_size = config.hidden_size self.self_attn = attn_module(config, layer_idx) - self.mlp = Qwen2MLP(config) - self.mlp_moe_gen = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = BagelMLP(config.hidden_size, config.intermediate_size, config.hidden_act) + self.mlp_moe_gen = BagelMLP(config.hidden_size, config.intermediate_size, config.hidden_act) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -391,7 +537,7 @@ def __init__(self, config): self.vocab_size = config.vocab_size self.use_moe = "Mo" in config.layer_module - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( [ Qwen2MoTDecoderLayer(config, layer_idx, attn_module=PackedAttentionMoT) @@ -399,10 +545,10 @@ def __init__(self, config): ] ) - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.use_moe: - self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.norm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = BagelRotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() @@ -528,6 +674,50 @@ def forward( return outputs + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights for vLLM parallel layers. 
+ + Handles stacked parameter remapping for QKVParallelLinear: + - q_proj, k_proj, v_proj -> qkv_proj (shard ids: q, k, v) + - q_proj_moe_gen, k_proj_moe_gen, v_proj_moe_gen -> qkv_proj_moe_gen + Other parallel layers (gate_proj, up_proj, down_proj, embed_tokens, etc.) + keep HF checkpoint names and use weight_loader for TP sharding. + """ + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # More specific _moe_gen patterns FIRST to avoid substring + # ambiguity (`.q_proj` is a substring of `.q_proj_moe_gen`). + (".qkv_proj_moe_gen", ".q_proj_moe_gen", "q"), + (".qkv_proj_moe_gen", ".k_proj_moe_gen", "k"), + (".qkv_proj_moe_gen", ".v_proj_moe_gen", "v"), + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict.get(name) + if param is None: + break + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): grid_h = np.arange(grid_size, dtype=np.float32) @@ -639,7 +829,7 @@ def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_pat return pos_ids -class Bagel(torch.nn.Module): +class Bagel(nn.Module): config_class = BagelConfig base_model_prefix = "bagel" diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index bdb9f1f5c3..040f7ddecc 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -402,7 +402,7 @@ def vae_transforms(img): ) # Fail fast with a clear error instead of CUDA gather OOB. max_tid = int(generation_input["packed_text_ids"].max().item()) - emb_n = int(self.language_model.model.embed_tokens.weight.shape[0]) + emb_n = int(self.language_model.vocab_size) if max_tid >= emb_n: raise ValueError( "Tokenizer/model vocab mismatch: max token id " @@ -438,7 +438,7 @@ def vae_transforms(img): ) # Fail fast for special tokens used by the image path as well. max_tid_img = int(generation_input["packed_text_ids"].max().item()) - emb_n = int(self.language_model.model.embed_tokens.weight.shape[0]) + emb_n = int(self.language_model.vocab_size) if max_tid_img >= emb_n: raise ValueError( "Tokenizer/model vocab mismatch (image path): max token id " @@ -483,6 +483,28 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: allowed = set(state.keys()) shapes = {k: tuple(v.shape) for k, v in state.items()} + tp_aware_params = {name for name, p in self.named_parameters() if hasattr(p, "weight_loader")} + + # Expand allowed/tp_aware_params with stacked param source names. + # QKVParallelLinear merges q_proj+k_proj+v_proj into qkv_proj; the + # checkpoint stores the original separate names. We must recognise + # those names so _filtered_weights does not drop them. 
+ _stacked_expansions = [ + (".qkv_proj", ".q_proj"), + (".qkv_proj", ".k_proj"), + (".qkv_proj", ".v_proj"), + (".qkv_proj_moe_gen", ".q_proj_moe_gen"), + (".qkv_proj_moe_gen", ".k_proj_moe_gen"), + (".qkv_proj_moe_gen", ".v_proj_moe_gen"), + ] + stacked_source_names: set[str] = set() + for name in list(allowed): + for target_suffix, source_suffix in _stacked_expansions: + if target_suffix in name: + stacked_source_names.add(name.replace(target_suffix, source_suffix)) + allowed.update(stacked_source_names) + tp_aware_params.update(stacked_source_names) + def _normalize_name(name: str) -> str: # Common wrappers/prefixes in checkpoints. for pfx in ("module.", "model."): @@ -536,7 +558,7 @@ def _filtered_weights(): for cand in _iter_candidate_names(name): if cand in allowed: # Only accept if tensor shape matches target param/buffer shape. - if tuple(tensor.shape) == shapes.get(cand): + if tuple(tensor.shape) == shapes.get(cand) or cand in tp_aware_params: picked = cand break else:
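
For readers following the weight-loading changes above, the sketch below illustrates, in plain Python with no vLLM or torch dependency, the checkpoint-name remapping that the new `Qwen2MoTModel.load_weights` and the `_stacked_expansions` handling in `pipeline_bagel.py` both rely on: separate `q_proj`/`k_proj`/`v_proj` (and their `_moe_gen` counterparts) checkpoint tensors are routed to the fused `qkv_proj`/`qkv_proj_moe_gen` parameters, each tagged with the shard id that `QKVParallelLinear`'s `weight_loader` expects. The `remap` helper and the example checkpoint names are illustrative only and are not part of this patch.

```python
# Standalone sketch of the stacked-parameter name remapping used by the
# load_weights changes in this patch. It only manipulates strings; the real
# code additionally calls each parameter's weight_loader to shard the tensor
# across tensor-parallel ranks.

STACKED_PARAMS_MAPPING = [
    # (fused param suffix, checkpoint suffix, shard id). The _moe_gen entries
    # come first because ".q_proj" is a substring of ".q_proj_moe_gen".
    (".qkv_proj_moe_gen", ".q_proj_moe_gen", "q"),
    (".qkv_proj_moe_gen", ".k_proj_moe_gen", "k"),
    (".qkv_proj_moe_gen", ".v_proj_moe_gen", "v"),
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]


def remap(checkpoint_name: str) -> tuple[str, str | None]:
    """Return (model parameter name, shard id) for a checkpoint tensor name."""
    for param_suffix, ckpt_suffix, shard_id in STACKED_PARAMS_MAPPING:
        if ckpt_suffix in checkpoint_name:
            return checkpoint_name.replace(ckpt_suffix, param_suffix), shard_id
    # Non-stacked weights (gate_proj, up_proj, embed_tokens, norms, ...) keep
    # their checkpoint name and are loaded without a shard id.
    return checkpoint_name, None


if __name__ == "__main__":
    # Example checkpoint names are illustrative, not taken from a real file.
    for name in [
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.k_proj_moe_gen.weight",
        "model.layers.0.mlp.gate_proj.weight",
    ]:
        print(name, "->", remap(name))
    # Expected output:
    #   ...q_proj.weight -> ('...qkv_proj.weight', 'q')
    #   ...k_proj_moe_gen.weight -> ('...qkv_proj_moe_gen.weight', 'k')
    #   ...gate_proj.weight -> ('...gate_proj.weight', None)
```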