66 changes: 55 additions & 11 deletions python/sglang/srt/models/hunyuan.py
@@ -206,6 +206,42 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return final_hidden_states.view(orig_shape)


def get_head_dim(config):
if hasattr(config, "head_dim"):
return int(config.head_dim)
if hasattr(config, "attention_head_dim"):
return int(config.attention_head_dim)

# Some HunYuan models do not follow the hidden_size // num_attention_heads rule,
# and a wrong head_dim can cause runtime errors, so raise if this field is missing.
raise ValueError("Missing head_dim in config; please set head_dim in config.json")


def check_head_dim(config):
# Some models may lack `head_dim` and use `attention_head_dim` instead.
# This attribute is also used by flashinfer_backend.py, so we check for
# consistency and raise an error if it's not met to avoid silent failures.
# Although we could adapt the HunYuan model to use `attention_head_dim`,
# flashinfer expects `head_dim`, so we enforce its presence for correctness.
calc_head_dim = config.hidden_size // config.num_attention_heads

if hasattr(config, "attention_head_dim"):
if calc_head_dim != config.attention_head_dim and not hasattr(
config, "head_dim"
):
# In this case, flashinfer (and other components) may compute a wrong head_dim.
raise ValueError(
f"HunYuan model config error: calculated head_dim {calc_head_dim} != attention_head_dim {config.attention_head_dim}."
f"\nPlease add head_dim: {config.attention_head_dim} to config.json to ensure correct inference."
)

if hasattr(config, "head_dim") and config.attention_head_dim != config.head_dim:
raise ValueError(
f"HunYuan model config error: head_dim ({config.head_dim}) != attention_head_dim ({config.attention_head_dim})."
f"\nPlease change head_dim to {config.attention_head_dim} in config.json to ensure correct inference."
)


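# A minimal, hypothetical sketch (not part of this file) of how the two helpers above
# resolve and validate head_dim; types.SimpleNamespace stands in for a HuggingFace
# config object here.
def _head_dim_example():
    from types import SimpleNamespace

    cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32, attention_head_dim=128)
    assert get_head_dim(cfg) == 128  # no head_dim set, falls back to attention_head_dim
    check_head_dim(cfg)  # 4096 // 32 == 128 matches attention_head_dim, so no error

    cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32, attention_head_dim=256)
    try:
        check_head_dim(cfg)  # mismatch and no explicit head_dim -> ValueError
    except ValueError:
        pass
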
class HunYuanAttention(nn.Module):

def __init__(
@@ -240,9 +276,11 @@ def __init__(
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads
)
# Prioritize `head_dim` but fall back to `attention_head_dim` for Hunyuan models.
self.head_dim = get_head_dim(config)

check_head_dim(config)

self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@@ -493,7 +531,6 @@ def forward(
hidden_states = self.get_input_embeddings(input_ids)
residual = None

cla_factor = _get_cla_factor(self.config)
prev_kv_states = None
for i in range(len(self.layers)):
layer = self.layers[i]
@@ -560,6 +597,11 @@ def __init__(
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight

self.hidden_size = config.hidden_size
self.head_dim = get_head_dim(config)

check_head_dim(config)

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
self.sampler = Sampler()
@@ -582,16 +624,14 @@ def _split_qkv_weight(self, qkv: torch.Tensor):
self.config, "num_key_value_heads", self.config.num_attention_heads
)
num_key_value_groups = num_attention_heads // num_kv_heads
hidden_size = self.config.hidden_size
attention_head_dim = self.config.hidden_size // num_attention_heads

qkv = qkv.reshape(
num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
num_kv_heads, num_key_value_groups + 2, self.head_dim, self.hidden_size
)
q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
q = q.reshape(-1, hidden_size)
k = k.reshape(-1, hidden_size)
v = v.reshape(-1, hidden_size)
q = q.reshape(-1, self.hidden_size)
k = k.reshape(-1, self.hidden_size)
v = v.reshape(-1, self.hidden_size)
return torch.concat((q, k, v))
# return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),
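# A hypothetical shape walk-through (toy sizes, not the real model dimensions) of the
# fused-QKV split above: the fused weight is viewed as
# (num_kv_heads, num_key_value_groups + 2, head_dim, hidden_size) and split along dim 1.
def _split_qkv_shape_example():
    import torch

    hidden_size, head_dim = 16, 4
    num_heads, num_kv_heads = 4, 2
    groups = num_heads // num_kv_heads  # query heads sharing each kv head

    fused = torch.randn((num_heads + 2 * num_kv_heads) * head_dim, hidden_size)
    fused = fused.reshape(num_kv_heads, groups + 2, head_dim, hidden_size)
    q, k, v = torch.split(fused, (groups, 1, 1), dim=1)
    assert q.reshape(-1, hidden_size).shape == (num_heads * head_dim, hidden_size)
    assert k.reshape(-1, hidden_size).shape == (num_kv_heads * head_dim, hidden_size)
    assert v.reshape(-1, hidden_size).shape == (num_kv_heads * head_dim, hidden_size)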

@@ -768,4 +808,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
)


EntryClass = HunYuanMoEV1ForCausalLM
class HunYuanDenseV1ForCausalLM(HunYuanMoEV1ForCausalLM):
pass


EntryClass = [HunYuanMoEV1ForCausalLM, HunYuanDenseV1ForCausalLM]