@@ -206,6 +206,42 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
206206 return final_hidden_states .view (orig_shape )
207207
208208
def get_head_dim(config):
    """Resolve the attention head dimension from a model config.

    ``head_dim`` takes priority; ``attention_head_dim`` is the fallback
    used by some HunYuan configs.

    Raises:
        ValueError: if neither attribute is present on ``config``.
    """
    for attr in ("head_dim", "attention_head_dim"):
        if hasattr(config, attr):
            return int(getattr(config, attr))

    # Some HunYuan models do not follow the usual
    # hidden_size // total_num_heads rule, so deriving a value here could
    # silently produce a wrong head dim; fail loudly instead.
    raise ValueError("Missing head dim config, try set head_dim in config.json")
218+
219+
def check_head_dim(config):
    """Validate that the config's head-dim attributes are mutually consistent.

    Some HunYuan models lack ``head_dim`` and use ``attention_head_dim``
    instead, while other components (e.g. flashinfer_backend.py) expect
    ``head_dim`` or derive it as hidden_size // num_attention_heads. To avoid
    silent failures downstream, raise early when those values disagree.

    Raises:
        ValueError: if ``attention_head_dim`` disagrees with the derived head
            dim and no explicit ``head_dim`` is set, or if ``head_dim`` and
            ``attention_head_dim`` are both set but differ.
    """
    calc_head_dim = config.hidden_size // config.num_attention_heads

    # Guard every attribute access with hasattr: a config that defines
    # head_dim but not attention_head_dim must validate cleanly rather than
    # raise AttributeError while building the error condition.
    has_attn_head_dim = hasattr(config, "attention_head_dim")
    has_head_dim = hasattr(config, "head_dim")

    if not has_attn_head_dim:
        return

    if calc_head_dim != config.attention_head_dim and not has_head_dim:
        # Without an explicit head_dim, flashinfer (and other components)
        # would fall back to the wrong derived value.
        raise ValueError(
            f"HunYuan model config error: calculated head_dim {calc_head_dim} != attention_head_dim {config.attention_head_dim}"
            + f"\nPlease add head_dim: {config.attention_head_dim} to config.json to ensure correct inference."
        )

    if has_head_dim and config.attention_head_dim != config.head_dim:
        raise ValueError(
            f"HunYuan model config error: head_dim({config.head_dim}) != attention_head_dim({config.attention_head_dim})"
            + f"\nPlease change head_dim to {config.attention_head_dim} in config.json to ensure correct inference."
        )
243+
244+
209245class HunYuanAttention (nn .Module ):
210246
211247 def __init__ (
@@ -240,9 +276,11 @@ def __init__(
240276 assert tp_size % self .total_num_kv_heads == 0
241277 self .num_kv_heads = max (1 , self .total_num_kv_heads // tp_size )
242278 # MistralConfig has an optional head_dim introduced by Mistral-Nemo
243- self .head_dim = getattr (
244- config , "head_dim" , self .hidden_size // self .total_num_heads
245- )
279+ # Prioritize `head_dim` but fall back to `attention_head_dim` for Hunyuan models.
280+ self .head_dim = get_head_dim (config )
281+
282+ check_head_dim (config )
283+
246284 self .q_size = self .num_heads * self .head_dim
247285 self .kv_size = self .num_kv_heads * self .head_dim
248286 self .scaling = self .head_dim ** - 0.5
@@ -493,7 +531,6 @@ def forward(
493531 hidden_states = self .get_input_embeddings (input_ids )
494532 residual = None
495533
496- cla_factor = _get_cla_factor (self .config )
497534 prev_kv_states = None
498535 for i in range (len (self .layers )):
499536 layer = self .layers [i ]
@@ -560,6 +597,11 @@ def __init__(
560597 if config .tie_word_embeddings :
561598 self .lm_head .weight = self .model .embed_tokens .weight
562599
600+ self .hidden_size = config .hidden_size
601+ self .head_dim = get_head_dim (config )
602+
603+ check_head_dim (config )
604+
563605 logit_scale = getattr (config , "logit_scale" , 1.0 )
564606 self .logits_processor = LogitsProcessor (config , logit_scale = logit_scale )
565607 self .sampler = Sampler ()
@@ -582,16 +624,14 @@ def _split_qkv_weight(self, qkv: torch.Tensor):
582624 self .config , "num_key_value_heads" , self .config .num_attention_heads
583625 )
584626 num_key_value_groups = num_attention_heads // num_kv_heads
585- hidden_size = self .config .hidden_size
586- attention_head_dim = self .config .hidden_size // num_attention_heads
587627
588628 qkv = qkv .reshape (
589- num_kv_heads , num_key_value_groups + 2 , attention_head_dim , hidden_size
629+ num_kv_heads , num_key_value_groups + 2 , self . head_dim , self . hidden_size
590630 )
591631 q , k , v = torch .split (qkv , (num_key_value_groups , 1 , 1 ), dim = 1 )
592- q = q .reshape (- 1 , hidden_size )
593- k = k .reshape (- 1 , hidden_size )
594- v = v .reshape (- 1 , hidden_size )
632+ q = q .reshape (- 1 , self . hidden_size )
633+ k = k .reshape (- 1 , self . hidden_size )
634+ v = v .reshape (- 1 , self . hidden_size )
595635 return torch .concat ((q , k , v ))
596636 # return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),
597637
@@ -768,4 +808,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
768808 )
769809
770810
771- EntryClass = HunYuanMoEV1ForCausalLM
class HunYuanDenseV1ForCausalLM(HunYuanMoEV1ForCausalLM):
    # Dense HunYuan variant. The body is empty, so it shares the entire
    # implementation with HunYuanMoEV1ForCausalLM; presumably the separate
    # class exists only so the dense architecture name in config.json can be
    # resolved by the model registry — TODO confirm against the registry.
    pass


# Register both architecture names as loadable entry points.
EntryClass = [HunYuanMoEV1ForCausalLM, HunYuanDenseV1ForCausalLM]
0 commit comments