Skip to content

Commit 4296021

Browse files
authored
[Hunyuan]: Fix Dense Model Support (#8117)
Signed-off-by: Asher Zhang <asherszhang@tencent.com>
1 parent 01857fa commit 4296021

1 file changed

Lines changed: 55 additions & 11 deletions

File tree

python/sglang/srt/models/hunyuan.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,42 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
206206
return final_hidden_states.view(orig_shape)
207207

208208

209+
def get_head_dim(config):
    """Return the per-head attention dimension declared in *config*.

    ``head_dim`` takes priority; ``attention_head_dim`` is the fallback used
    by some HunYuan checkpoints. Because several HunYuan models do not follow
    the usual ``hidden_size // num_attention_heads`` rule, silently deriving a
    value could produce a wrong head dim and a confusing runtime failure —
    so a missing field raises immediately instead.
    """
    for field in ("head_dim", "attention_head_dim"):
        value = getattr(config, field, None)
        if value is not None:
            return int(value)

    raise ValueError("Missing head dim config, try set head_dim in config.json")
219+
220+
def check_head_dim(config):
    """Fail loudly if the head-dim related config fields are inconsistent.

    Some models lack ``head_dim`` and carry ``attention_head_dim`` instead.
    ``head_dim`` is also read by flashinfer_backend.py, so a mismatch there
    would not error — it would just compute attention with the wrong size.
    Rather than adapting the model to ``attention_head_dim``, we require the
    config to be self-consistent and raise otherwise.
    """
    derived = config.hidden_size // config.num_attention_heads

    if not hasattr(config, "attention_head_dim"):
        return

    explicit = hasattr(config, "head_dim")

    if not explicit and derived != config.attention_head_dim:
        # Without an explicit head_dim, flashinfer (and other components)
        # would fall back to the derived value — which is wrong here.
        raise ValueError(
            f"HunYuan model config error: calculated head_dim {derived} != attention_head_dim {config.attention_head_dim}"
            + f"\nPlease Add head_dim:{config.attention_head_dim} in config.json to make sure correctly inference."
        )

    if explicit and config.attention_head_dim != config.head_dim:
        raise ValueError(
            f"HunYuan model config error: head_dim({config.head_dim}) != attention_head_dim({config.attention_head_dim})"
            + f"\nPlease change head_dim:{config.attention_head_dim} in config.json to make sure correctly inference."
        )
243+
244+
209245
class HunYuanAttention(nn.Module):
210246

211247
def __init__(
@@ -240,9 +276,11 @@ def __init__(
240276
assert tp_size % self.total_num_kv_heads == 0
241277
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
242278
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
243-
self.head_dim = getattr(
244-
config, "head_dim", self.hidden_size // self.total_num_heads
245-
)
279+
# Prioritize `head_dim` but fall back to `attention_head_dim` for Hunyuan models.
280+
self.head_dim = get_head_dim(config)
281+
282+
check_head_dim(config)
283+
246284
self.q_size = self.num_heads * self.head_dim
247285
self.kv_size = self.num_kv_heads * self.head_dim
248286
self.scaling = self.head_dim**-0.5
@@ -493,7 +531,6 @@ def forward(
493531
hidden_states = self.get_input_embeddings(input_ids)
494532
residual = None
495533

496-
cla_factor = _get_cla_factor(self.config)
497534
prev_kv_states = None
498535
for i in range(len(self.layers)):
499536
layer = self.layers[i]
@@ -560,6 +597,11 @@ def __init__(
560597
if config.tie_word_embeddings:
561598
self.lm_head.weight = self.model.embed_tokens.weight
562599

600+
self.hidden_size = config.hidden_size
601+
self.head_dim = get_head_dim(config)
602+
603+
check_head_dim(config)
604+
563605
logit_scale = getattr(config, "logit_scale", 1.0)
564606
self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
565607
self.sampler = Sampler()
@@ -582,16 +624,14 @@ def _split_qkv_weight(self, qkv: torch.Tensor):
582624
self.config, "num_key_value_heads", self.config.num_attention_heads
583625
)
584626
num_key_value_groups = num_attention_heads // num_kv_heads
585-
hidden_size = self.config.hidden_size
586-
attention_head_dim = self.config.hidden_size // num_attention_heads
587627

588628
qkv = qkv.reshape(
589-
num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
629+
num_kv_heads, num_key_value_groups + 2, self.head_dim, self.hidden_size
590630
)
591631
q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
592-
q = q.reshape(-1, hidden_size)
593-
k = k.reshape(-1, hidden_size)
594-
v = v.reshape(-1, hidden_size)
632+
q = q.reshape(-1, self.hidden_size)
633+
k = k.reshape(-1, self.hidden_size)
634+
v = v.reshape(-1, self.hidden_size)
595635
return torch.concat((q, k, v))
596636
# return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),
597637

@@ -768,4 +808,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
768808
)
769809

770810

771-
EntryClass = HunYuanMoEV1ForCausalLM
811+
class HunYuanDenseV1ForCausalLM(HunYuanMoEV1ForCausalLM):
    # Dense variant registered under its own architecture name. It adds no
    # overrides — presumably the shared implementation handles the dense case
    # purely from the config (NOTE(review): confirm against HunYuanModel's
    # layer construction, which is outside this diff).
    pass


# Architectures exported by this module; the loader matches the config's
# `architectures` entry against these class names.
EntryClass = [HunYuanMoEV1ForCausalLM, HunYuanDenseV1ForCausalLM]

0 commit comments

Comments
 (0)