223 changes: 223 additions & 0 deletions convert_hf_to_gguf.py
@@ -618,6 +618,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
# ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5
res = "gpt-2"
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
# ref: https://huggingface.co/LumiOpen/Viking-7B
res = "viking"
@@ -3948,6 +3951,226 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
return [(self.map_tensor_name(name), data_torch)]
return super().modify_tensors(data_torch, name, bid)

@ModelBase.register("Glm4MoeForCausalLM")
class Glm4MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
self.block_count = self.hparams["num_hidden_layers"] + 1
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
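# e.g. num_hidden_layers = 46 (GLM-4.5-Air's published value, cited here as an
# assumption) gives block_count = 47, so the tensor map also covers the NextN
# "blk.46.*" names.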

def set_vocab(self):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
self.dir_model, trust_remote_code=True
)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

# Set special tokens
special_vocab._set_special_token(
"eos", tokenizer.get_added_vocab()["<|endoftext|>"]
)
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
special_vocab._set_special_token(
"unk", tokenizer.get_added_vocab()["<|endoftext|>"]
)
special_vocab._set_special_token(
"bos", tokenizer.get_added_vocab()["<|endoftext|>"]
)
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338

# Fix chat template syntax error in GLM-4.5 models
if special_vocab.chat_template and isinstance(special_vocab.chat_template, str):
# Fix multiple syntax issues in GLM-4.5 chat template
template = special_vocab.chat_template
# Fix nested double quotes issue
template = template.replace('endswith("/nothink")', "endswith('/nothink')")
# Remove a stray extra closing parenthesis after the endswith() call
template = template.replace(
"not visible_text(m.content).endswith('/nothink'))",
"not visible_text(m.content).endswith('/nothink')"
)
special_vocab.chat_template = template
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
super().set_gguf_parameters()
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = (
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
)
self.gguf_writer.add_rope_dimension_count(
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
)
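# e.g. head_dim = 128 with partial_rotary_factor = 0.5 (illustrative values)
# stores a rope dimension count of 64.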

# MoE parameters
if (n_experts := self.hparams.get("n_routed_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
# Note: expert_used_count is already set by parent class using num_experts_per_tok
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
self.gguf_writer.add_expert_shared_count(n_shared_experts)
if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)

# Expert gating function (sigmoid for GLM4_MOE)
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
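# Hard-coded rather than read from config: GLM-4.5 is assumed here to use
# sigmoid expert scoring (DeepSeek-V3-style routing).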

# Routed scaling factor
if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)

# Normalise topk probabilities
if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
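# Illustrative end-to-end values (made up, not read from any checkpoint):
# n_routed_experts=128, moe_intermediate_size=1408, n_shared_experts=1 and
# first_k_dense_replace=1 would write expert_count=128, an expert FFN length
# of 1408, one shared expert, and one leading dense block.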

_experts: list[dict[str, Tensor]] | None = None
_shared_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.visual."): # ignore visual part
return []
elif name.startswith("model.language_model."):
name = name.replace("language_model.", "") # for multimodal variants

# Handle main token embedding (but not layer-specific NextN embeddings)
if name == "model.embed_tokens.weight" and ".layers." not in name:
return [(self.map_tensor_name("token_embd.weight"), data_torch)]

# Handle routed experts
if name.find("mlp.experts") != -1 and "shared_experts" not in name:
n_experts = self.hparams["n_routed_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

# Extend experts array if needed (for models where actual layers > num_hidden_layers)
while len(self._experts) <= bid:
self._experts.append({})

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)
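# Shape after stacking: (n_experts, out_features, in_features), e.g. a layer's
# down_proj weights of shape (hidden_size, moe_intermediate_size) become one
# (n_experts, hidden_size, moe_intermediate_size) tensor.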
# Generate GGUF tensor names for merged experts
if w_name == "down_proj":
new_name = f"blk.{bid}.ffn_down_exps.weight"
elif w_name == "gate_proj":
new_name = f"blk.{bid}.ffn_gate_exps.weight"
elif w_name == "up_proj":
new_name = f"blk.{bid}.ffn_up_exps.weight"
else:
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []

# Handle expert gating input (routing gate)
if ".mlp.gate.e_score_correction_bias" in name:
new_name = name.replace("model.layers.", "blk.").replace(
".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias"
)
return [(new_name, data_torch)]
elif ".mlp.gate.weight" in name:
new_name = name.replace("model.layers.", "blk.").replace(
".mlp.gate.weight", ".ffn_gate_inp.weight"
)
return [(new_name, data_torch)]

# Handle shared expert tensors
if ".mlp.shared_experts." in name:
new_name = name.replace("model.layers.", "blk.").replace(".mlp.shared_experts.", ".ffn_")
if "gate_proj" in new_name:
new_name = new_name.replace("gate_proj", "gate_shexp")
elif "down_proj" in new_name:
new_name = new_name.replace("down_proj", "down_shexp")
elif "up_proj" in new_name:
new_name = new_name.replace("up_proj", "up_shexp")
return [(new_name, data_torch)]
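# e.g. "model.layers.3.mlp.shared_experts.gate_proj.weight"
#      -> "blk.3.ffn_gate_shexp.weight"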

# Handle regular dense FFN layers (for hybrid dense/MoE architecture)
if ".mlp." in name and "experts" not in name and "_shexp" not in name:
if "gate_proj" in name:
new_name = name.replace("model.layers.", "blk.").replace(
".mlp.gate_proj.weight", ".ffn_gate.weight"
)
elif "up_proj" in name:
new_name = name.replace("model.layers.", "blk.").replace(
".mlp.up_proj.weight", ".ffn_up.weight"
)
elif "down_proj" in name:
new_name = name.replace("model.layers.", "blk.").replace(
".mlp.down_proj.weight", ".ffn_down.weight"
)
else:
new_name = name
return [(self.map_tensor_name(new_name), data_torch)]

# Handle special NextN tensors - preserve for future MTP support
if (
".embed_tokens." in name
or ".shared_head." in name
or ".eh_proj." in name
or ".enorm." in name
or ".hnorm." in name
):
new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "")
return [(new_name, data_torch)]
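# e.g. "model.layers.46.eh_proj.weight" -> "blk.46.eh_proj"
# (note that the ".weight" suffix is stripped for these preserved tensors)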

# GLM tensor mapping - handle directly without map_tensor_name
if ".input_layernorm." in name:
new_name = name.replace("model.layers.", "blk.").replace(".input_layernorm.", ".attn_norm.")
return [(new_name, data_torch)]
elif ".post_attention_layernorm." in name:
new_name = name.replace("model.layers.", "blk.").replace(".post_attention_layernorm.", ".ffn_norm.")
return [(new_name, data_torch)]
elif ".self_attn." in name:
# Map GLM self_attn to standard attention naming
new_name = name.replace("model.layers.", "blk.").replace(".self_attn.", ".attn_")
if "q_proj" in new_name:
new_name = new_name.replace("q_proj", "q")
elif "k_proj" in new_name:
new_name = new_name.replace("k_proj", "k")
elif "v_proj" in new_name:
new_name = new_name.replace("v_proj", "v")
elif "o_proj" in new_name:
new_name = new_name.replace("o_proj", "output")
return [(new_name, data_torch)]
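# e.g. "model.layers.0.self_attn.q_proj.weight" -> "blk.0.attn_q.weight"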

return super().modify_tensors(data_torch, name, bid)

def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")

@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(Model):
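For orientation, a minimal self-contained sketch of the expert-merge pattern used in Glm4MoeModel.modify_tensors above (all sizes are made-up illustrative values; the real matrices come from the HF checkpoint):

    import torch

    # Made-up dimensions, for illustration only.
    n_experts, hidden, inter = 4, 8, 16
    # One gate/up-style projection weight per routed expert, shape (inter, hidden).
    per_expert = [torch.randn(inter, hidden) for _ in range(n_experts)]
    # The converter stacks the per-expert matrices along a new leading axis,
    # yielding the single 3-D "blk.{bid}.ffn_*_exps" tensor written to the GGUF.
    merged = torch.stack(per_expert, dim=0)
    assert merged.shape == (n_experts, inter, hidden)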
51 changes: 51 additions & 0 deletions gguf-py/gguf/constants.py
@@ -220,6 +220,7 @@ class MODEL_ARCH(IntEnum):
OPENELM = auto()
ARCTIC = auto()
DEEPSEEK2 = auto()
GLM4_MOE = auto()
CHATGLM = auto()
BITNET = auto()
BITNET_25 = auto()
@@ -262,6 +263,9 @@ class MODEL_TENSOR(IntEnum):
FFN_GATE_EXP = auto()
FFN_DOWN_EXP = auto()
FFN_UP_EXP = auto()
FFN_GATE_EXPS = auto() # merged experts
FFN_DOWN_EXPS = auto() # merged experts
FFN_UP_EXPS = auto() # merged experts
FFN_GATE_SHEXP = auto()
FFN_DOWN_SHEXP = auto()
FFN_UP_SHEXP = auto()
@@ -314,6 +318,12 @@ class MODEL_TENSOR(IntEnum):
ENC_FFN_DOWN = auto()
ENC_FFN_UP = auto()
ENC_OUTPUT_NORM = auto()
NEXTN_EH_PROJ = auto() # nextn tensors (glm4moe)
NEXTN_EMBED_TOKENS = auto() # nextn tensors (glm4moe)
NEXTN_ENORM = auto() # nextn tensors (glm4moe)
NEXTN_HNORM = auto() # nextn tensors (glm4moe)
NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe)
NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe)


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -358,6 +368,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4_MOE: "glm4moe",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.BITNET_25: "bitnet-25",
MODEL_ARCH.T5: "t5",
@@ -404,6 +415,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_GATE_EXPS: "blk.{bid}.ffn_gate_exps", # merged experts
MODEL_TENSOR.FFN_DOWN_EXPS: "blk.{bid}.ffn_down_exps", # merged experts
MODEL_TENSOR.FFN_UP_EXPS: "blk.{bid}.ffn_up_exps", # merged experts
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
@@ -451,6 +465,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
# NextN/MTP tensors (GLM4_MOE)
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.embed_tokens",
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.enorm",
MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.hnorm",
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.shared_head.head",
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.shared_head.norm",
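# e.g. TENSOR_NAMES[MODEL_TENSOR.NEXTN_EH_PROJ].format(bid=46) -> "blk.46.eh_proj"
# (bid = 46 is illustrative: the single NextN layer sits at index num_hidden_layers)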
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -1070,6 +1091,36 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GLM4_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE, # dense layers
MODEL_TENSOR.FFN_DOWN, # dense layers
MODEL_TENSOR.FFN_UP, # dense layers
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXPS,
MODEL_TENSOR.FFN_DOWN_EXPS,
MODEL_TENSOR.FFN_UP_EXPS,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,