# coding=utf-8
"""PyTorch MAMBA model."""
-from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple

import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
KVCache = Tuple[torch.Tensor, torch.Tensor]


-@dataclass
-class MambaCacheParams:
-    is_prompt: bool = False
-    conv_state: torch.Tensor = torch.Tensor()
-    ssm_state: torch.Tensor = torch.Tensor()
-
-
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class MambaMixer(nn.Module):
    """
@@ -209,37 +200,6 @@ def forward(self, hidden_states: torch.Tensor,
        return contextualized_states


-class MambaMLP(nn.Module):
-
-    def __init__(
-        self,
-        config: MambaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        hidden_size = config.hidden_size
-        intermediate_size = config.intermediate_size
-        hidden_act = config.hidden_act
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-    def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
-
-
class MambaDecoderLayer(nn.Module):

    def __init__(self,
@@ -252,7 +212,6 @@ def __init__(self,
        self.config = config
        self.mixer = MambaMixer(config, layer_idx)

-        self.feed_forward = MambaMLP(config, quant_config=quant_config)
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.layer_norm_epsilon)
@@ -274,10 +233,6 @@ def forward(

        hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
                                   ssm_state)
-        # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
-        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


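With the feed-forward branch removed here and the MambaMLP class dropped above, the decoder layer reduces to a norm followed by the mixer. A minimal self-contained sketch of that remaining control flow; the stub modules and the manual RMS normalization below are stand-ins, not vLLM APIs:

import torch
import torch.nn as nn


class ToyMambaLayer(nn.Module):
    # Stand-in for MambaDecoderLayer after this change: norm -> mixer, no MLP.
    # The conv/ssm state arguments are omitted because the toy mixer ignores them.

    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # RMSNorm scale
        self.mixer = nn.Identity()  # placeholder for MambaMixer (conv + SSM scan)

    def forward(self, hidden_states, residual=None):
        # Residual add folded into the norm step, mirroring RMSNorm(x, residual).
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        normed = hidden_states * torch.rsqrt(
            hidden_states.pow(2).mean(-1, keepdim=True) + 1e-5) * self.weight
        return self.mixer(normed), residual


# Usage: the (hidden_states, residual) pair threads from layer to layer.
layer = ToyMambaLayer(8)
h, r = layer(torch.randn(2, 8))
h, r = layer(h, r)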
@@ -319,7 +274,6 @@ def forward(
319274 self ,
320275 input_ids : torch .Tensor ,
321276 positions : torch .Tensor ,
322- kv_caches : List [torch .Tensor ],
323277 attn_metadata : AttentionMetadata ,
324278 conv_state : torch .Tensor ,
325279 ssm_state : torch .Tensor ,
@@ -346,26 +300,6 @@ def forward(


class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    embedding_modules = {
-        "embeddings": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
@@ -416,8 +350,8 @@ def forward(self,
        mamba_cache_tensors = self.mamba_cache.current_run_tensors(
            input_ids, attn_metadata, **kwargs)

-        hidden_states = self.backbone(input_ids, positions, kv_caches,
-                                      attn_metadata, mamba_cache_tensors[0],
+        hidden_states = self.backbone(input_ids, positions, attn_metadata,
+                                      mamba_cache_tensors[0],
                                      mamba_cache_tensors[1])

        return hidden_states
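For reference, a hedged sketch of the call site after this hunk: kv_caches is gone, and the two tensors returned by the Mamba cache manager feed the backbone directly. The helper below is illustrative only, not vLLM code:

def run_backbone(backbone, input_ids, positions, attn_metadata, mamba_cache_tensors):
    # mamba_cache_tensors[0] is the conv state and [1] the ssm state,
    # matching the positional order of MambaModel.forward shown above.
    conv_state, ssm_state = mamba_cache_tensors
    return backbone(input_ids, positions, attn_metadata, conv_state, ssm_state)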
@@ -457,43 +391,16 @@ def sample(
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
            if "A_log" in name:
                name = name.replace("A_log", "A")

-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
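With the attention-style remapping removed, the only renaming the loader still performs is A_log -> A; every other weight goes straight to the parameter's own weight_loader, with default_weight_loader as the fallback. A small self-contained illustration of that rename (the parameter names here are hypothetical examples, not taken from the diff):

def remap_mamba_name(name: str) -> str:
    # Mirrors the single substitution kept in load_weights above.
    return name.replace("A_log", "A") if "A_log" in name else name


assert remap_mamba_name("backbone.layers.0.mixer.A_log") == "backbone.layers.0.mixer.A"
assert remap_mamba_name("backbone.embeddings.weight") == "backbone.embeddings.weight"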