diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 34281b2e99ee..263f4c8379cf 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -25,6 +25,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -33,7 +34,7 @@
 
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsV0Only)
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -87,23 +88,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states.view(orig_shape)
 
 
-class JambaMLP(JambaMoE):
-
-    def __init__(self,
-                 config: JambaConfig,
-                 params_dtype: Optional[torch.dtype] = None,
-                 tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
-        super().__init__(config,
-                         num_experts=1,
-                         top_k=1,
-                         params_dtype=params_dtype,
-                         tp_size=tp_size,
-                         quant_config=quant_config,
-                         prefix=prefix)
-
-
 class JambaMambaDecoderLayer(nn.Module):
 
     def __init__(self,
@@ -132,10 +116,20 @@ def __init__(self,
         )
 
         num_experts = config.layers_num_experts[layer_idx]
-        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.feed_forward")
+        if num_experts > 1:
+            self.feed_forward = JambaMoE(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.feed_forward",
+            )
+        else:
+            self.feed_forward = JambaMLP(
+                config.hidden_size,
+                config.intermediate_size,
+                config.hidden_act,
+                quant_config=quant_config,
+                prefix=f"{prefix}.feed_forward",
+            )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
@@ -216,10 +210,20 @@ def __init__(self,
         )
 
         num_experts = config.layers_num_experts[layer_idx]
-        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.feed_forward")
+        if num_experts > 1:
+            self.feed_forward = JambaMoE(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.feed_forward",
+            )
+        else:
+            self.feed_forward = JambaMLP(
+                config.hidden_size,
+                config.intermediate_size,
+                config.hidden_act,
+                quant_config=quant_config,
+                prefix=f"{prefix}.feed_forward",
+            )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
@@ -359,15 +363,97 @@ def forward(
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if 'experts' in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for (
+                        param_name,
+                        weight_name,
+                        expert_id,
+                        shard_id,
+                ) in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                        IsHybrid, SupportsV0Only):
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
+        ".self_attn.": ".",
+        ".A_log": ".A"
+    }, )
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
             "k_proj",
             "v_proj",
         ],
+        "gate_up_proj": ["gate_proj", "up_proj"],
         "in_proj": ["in_proj"],
     }
 
@@ -468,96 +554,11 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-
-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
-            if "feed_forward" in name and not _is_moe_layer(name):
-                ## map MLP layers to expert with ID=0
-                name = name.replace("feed_forward", "feed_forward.experts.0")
-
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                if 'experts' in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip layers on other devices.
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for (
-                        param_name,
-                        weight_name,
-                        expert_id,
-                        shard_id,
-                ) in expert_params_mapping:
-                    if weight_name not in name:
-                        continue
-
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    if is_pp_missing_parameter(name, self):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
-
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
-def _is_moe_layer(name: str):
-    return any(
-        [experts_name in name for experts_name in [
-            "experts",
-            "router",
-        ]])
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
 
 
 class JambaForSequenceClassification(JambaForCausalLM):
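
Reviewer note (not part of the patch): the name munging that the old JambaForCausalLM.load_weights did by hand ("A_log" -> "A", stripping ".self_attn.", remapping dense feed_forward layers onto expert 0) is now expressed declaratively. The WeightsMapper substring rules translate HuggingFace checkpoint names, AutoWeightsLoader dispatches weights to the submodules' own load_weights, and the dense MLP is a LlamaMLP loaded through the new ".gate_up_proj" rows of stacked_params_mapping instead of a single-expert JambaMoE. The sketch below is a standalone Python illustration of that substring remapping; it does not import vLLM, and the checkpoint keys it uses are hypothetical examples, not real Jamba weight names.

    # Standalone sketch of the substring remapping declared by
    # hf_to_vllm_mapper in this patch; example keys are hypothetical.
    ORIG_TO_NEW_SUBSTR = {
        ".self_attn.": ".",  # attention projections sit directly on the decoder layer in vLLM
        ".A_log": ".A",      # the checkpoint stores A_log; the vLLM Mamba parameter is A
    }

    def remap(name: str) -> str:
        """Apply every substring rule to one checkpoint weight name."""
        for old, new in ORIG_TO_NEW_SUBSTR.items():
            name = name.replace(old, new)
        return name

    # Hypothetical checkpoint keys, for illustration only.
    assert (remap("model.layers.4.self_attn.qkv_proj.weight")
            == "model.layers.4.qkv_proj.weight")
    assert remap("model.layers.2.mamba.A_log") == "model.layers.2.mamba.A"

The added "gate_up_proj": ["gate_proj", "up_proj"] entry in packed_modules_mapping mirrors those new ".gate_up_proj" rows in stacked_params_mapping, so consumers of packed_modules_mapping (e.g. LoRA) see the fused gate/up projection of the Llama-style MLP consistently with how its shards are stacked at load time.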