116 changes: 110 additions & 6 deletions convert_hf_to_gguf.py
@@ -8486,8 +8486,18 @@ def set_vocab(self):
class NemotronHModel(GraniteHybridModel):
"""Hybrid mamba2/attention model from NVIDIA"""
model_arch = gguf.MODEL_ARCH.NEMOTRON_H
is_moe: bool = False

def __init__(self, *args, **kwargs):
# We have to determine the correct model architecture (MoE vs non-MoE) before
# calling the parent __init__. This is because the parent constructor
# uses self.model_arch to build the tensor name map, and all MoE-specific
# mappings would be missed if it were called with the default non-MoE arch.
hparams = ModelBase.load_hparams(args[0], self.is_mistral_format)
if "num_experts_per_tok" in hparams:
self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
self.is_moe = True

super().__init__(*args, **kwargs)

# Save the top-level head_dim for later
@@ -8499,9 +8509,11 @@ def __init__(self, *args, **kwargs):

# Update the ssm / attn / mlp layers
# M: Mamba2, *: Attention, -: MLP
# MoE:
# M: Mamba2, *: Attention, E: Expert
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")]

def get_attn_layers(self):
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
@@ -8517,18 +8529,110 @@ def set_gguf_parameters(self):
# Set feed_forward_length
# NOTE: This will trigger an override warning. This is preferable to
# duplicating all the parent logic
n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
self.gguf_writer.add_feed_forward_length([
n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
])
if not self.is_moe:
n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
self.gguf_writer.add_feed_forward_length([
n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
])
else:
moe_intermediate_size = self.hparams["moe_intermediate_size"]
self.gguf_writer.add_feed_forward_length([
moe_intermediate_size if i in self._mlp_layers else 0 for i in range(self.block_count)
])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["moe_shared_expert_intermediate_size"])
self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_group_count(self.hparams["n_group"])

# number of experts used per token (top-k)
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)

def set_vocab(self):
super().set_vocab()

# The tokenizer _does_ add a BOS token (via post_processor type
# TemplateProcessing) but does not set add_bos_token to true in the
# config, so we need to explicitly override it here.
self.gguf_writer.add_add_bos_token(True)
if not self.is_moe:
self.gguf_writer.add_add_bos_token(True)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if self.is_moe and bid is not None:
if name.endswith("mixer.gate.e_score_correction_bias"):
new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
mapped_name = self.map_tensor_name(new_name)
return [(mapped_name, data_torch)]

if name.endswith("mixer.dt_bias"):
new_name = name.replace("dt_bias", "dt.bias")
mapped_name = self.map_tensor_name(new_name)
return [(mapped_name, data_torch)]

if name.endswith("mixer.conv1d.weight"):
squeezed_data = data_torch.squeeze()
mapped_name = self.map_tensor_name(name)
return [(mapped_name, squeezed_data)]

if name.endswith("mixer.A_log"):
transformed_data = -torch.exp(data_torch)
reshaped_data = transformed_data.squeeze().reshape(-1, 1)
mapped_name = self.map_tensor_name(name)
return [(mapped_name, reshaped_data)]

if name.endswith("mixer.D"):
reshaped_data = data_torch.squeeze().reshape(-1, 1)
mapped_name = self.map_tensor_name(name)
return [(mapped_name, reshaped_data)]

if name.endswith("mixer.norm.weight"):
reshaped_data = data_torch.reshape(8, 512)
mapped_name = self.map_tensor_name(name)
return [(mapped_name, reshaped_data)]

if name.find("mixer.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 2:
# merge the experts into a single tensor
tensors: list[tuple[str, Tensor]] = []
for w_name in ["down_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"backbone.layers.{bid}.mixer.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))

return tensors
else:
return []

return super().modify_tensors(data_torch, name, bid)

def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("BailingMoeForCausalLM")
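The converter changes above do two MoE-specific things: they classify layers from the `hybrid_override_pattern` string (using `E` instead of `-` for the FFN slots), and in `modify_tensors` they collect the per-expert `down_proj`/`up_proj` weights and stack them into one merged tensor per projection. A minimal standalone sketch of both ideas; the pattern string, layer index, and shapes are illustrative assumptions, not values from a real checkpoint:

```python
import torch

# Hypothetical pattern: M = Mamba2, * = attention, E = MoE expert block, - = dense MLP
pattern = "MM*ME-MM"
ssm_layers  = [i for i, c in enumerate(pattern) if c == "M"]
attn_layers = [i for i, c in enumerate(pattern) if c == "*"]
moe_layers  = [i for i, c in enumerate(pattern) if c == "E"]  # what _mlp_layers holds when is_moe

# Per-expert HF tensors for one MoE layer (4 experts, hidden=8, moe_intermediate=16)
n_experts, hidden, inter = 4, 8, 16
experts = {
    f"backbone.layers.4.mixer.experts.{x}.up_proj.weight": torch.randn(inter, hidden)
    for x in range(n_experts)
}

# Stack experts 0..n-1 along a new leading dim, as modify_tensors() does once all
# expert tensors for the layer have been seen; the result maps to blk.4.ffn_up_exps.
merged = torch.stack(
    [experts[f"backbone.layers.4.mixer.experts.{x}.up_proj.weight"] for x in range(n_experts)],
    dim=0,
)
assert merged.shape == (n_experts, inter, hidden)
```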
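After conversion, the MoE hyperparameters written by `set_gguf_parameters` should appear as `{arch}`-prefixed metadata keys. A rough readback sketch, assuming gguf-py's `GGUFReader` and the usual key naming; the file path is hypothetical:

```python
from gguf import GGUFReader

reader = GGUFReader("nemotron-h-moe.gguf")  # hypothetical converted file
# Expected keys include nemotron_h_moe.expert_count, nemotron_h_moe.expert_used_count,
# nemotron_h_moe.expert_feed_forward_length, nemotron_h_moe.expert_shared_count, ...
for key in reader.fields:
    if key.startswith("nemotron_h_moe.expert"):
        print(key)
```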
29 changes: 29 additions & 0 deletions gguf-py/gguf/constants.py
@@ -413,6 +413,7 @@ class MODEL_ARCH(IntEnum):
JAIS = auto()
NEMOTRON = auto()
NEMOTRON_H = auto()
NEMOTRON_H_MOE = auto()
EXAONE = auto()
EXAONE4 = auto()
GRANITE = auto()
@@ -786,6 +787,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.NEMOTRON_H: "nemotron_h",
MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.EXAONE4: "exaone4",
MODEL_ARCH.GRANITE: "granite",
@@ -2529,6 +2531,33 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.NEMOTRON_H_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_OUT,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
# experts
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
# shared expert
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.EXAONE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
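A small sanity check (assuming the in-tree gguf-py package is importable) that the new enum entry, its name, and its tensor list wire up as added above:

```python
import gguf

arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
assert gguf.MODEL_ARCH_NAMES[arch] == "nemotron_h_moe"

tensors = gguf.MODEL_TENSORS[arch]
assert gguf.MODEL_TENSOR.FFN_UP_EXP in tensors     # routed experts
assert gguf.MODEL_TENSOR.FFN_UP_SHEXP in tensors   # shared expert
assert gguf.MODEL_TENSOR.SSM_CONV1D in tensors     # mamba2 layers
```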
9 changes: 7 additions & 2 deletions gguf-py/gguf/tensor_mapping.py
@@ -377,6 +377,7 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.gate", # lfm2moe
"model.layers.{bid}.mlp.router.gate", # afmoe
"layers.{bid}.gate", # mistral-large
"backbone.layers.{bid}.mixer.gate", # nemotron-h-moe
),

MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -390,6 +391,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.expert_bias", # afmoe
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
"backbone.layers.{bid}.mixer.gate.e_score_correction" # nemotron-h-moe
),

# Feed-forward up
@@ -438,7 +440,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe, nemotron-h-moe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
@@ -452,6 +454,7 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.down_proj",
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
"layers.{bid}.shared_experts.w3", # mistral-large
"backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe
),

MODEL_TENSOR.FFN_UP_CHEXP: (
@@ -546,7 +549,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe nemotron-h-moe (merged)
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
@@ -561,6 +564,7 @@ class TensorNameMap:
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
"layers.{bid}.shared_experts.w2", # mistral-large
"backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe
),

MODEL_TENSOR.FFN_DOWN_CHEXP: (
@@ -704,6 +708,7 @@ class TensorNameMap:
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
"model.layers.{bid}.linear_attn.dt_proj", # qwen3next
"backbone.layers.{bid}.mixer.dt", # nemotron-h-moe
),

MODEL_TENSOR.SSM_DT_NORM: (
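A quick sketch (assuming gguf-py's `get_tensor_name_map`/`TensorNameMap` API) of how the new `backbone.*` entries resolve for nemotron-h-moe; the block index is arbitrary:

```python
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.NEMOTRON_H_MOE, n_blocks=8)

# Router gate -> FFN_GATE_INP
print(tmap.get_name("backbone.layers.3.mixer.gate.weight", try_suffixes=(".weight",)))
# expected: blk.3.ffn_gate_inp.weight

# Shared expert up projection -> FFN_UP_SHEXP
print(tmap.get_name("backbone.layers.3.mixer.shared_experts.up_proj.weight", try_suffixes=(".weight",)))
# expected: blk.3.ffn_up_shexp.weight
```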
35 changes: 35 additions & 0 deletions src/llama-arch.cpp
@@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
{ LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_RWKV6, "rwkv6" },
@@ -1763,6 +1764,39 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_NEMOTRON_H_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
// mamba(2) ssm layers
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
// attention layers
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
// dense FFN
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
// MoE FFN (for MoE layers)
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
// MoE shared expert layer
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_EXAONE,
{
@@ -2817,6 +2851,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_NEMOTRON_H:
case LLM_ARCH_NEMOTRON_H_MOE:
case LLM_ARCH_QWEN3NEXT:
return true;
default:
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -79,6 +79,7 @@ enum llm_arch {
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_NEMOTRON_H,
LLM_ARCH_NEMOTRON_H_MOE,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
9 changes: 9 additions & 0 deletions src/llama-graph.cpp
@@ -1089,6 +1089,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cur = ggml_relu(ctx0, cur);
cb(cur, "ffn_moe_relu", il);
} break;
case LLM_FFN_RELU_SQR:
if (gate_exps) {
// TODO: add support for gated squared relu
GGML_ABORT("fatal error: gated squared relu not implemented");
} else {
cur = ggml_relu(ctx0, cur);
cur = ggml_sqr(ctx0, cur);
cb(cur, "ffn_moe_relu_sqr", il);
} break;
default:
GGML_ABORT("fatal error");
}
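For reference, the new `LLM_FFN_RELU_SQR` branch for ungated MoE FFNs computes a squared ReLU: `ggml_relu` followed by an elementwise `ggml_sqr`. The same math in a tiny torch sketch (the input tensor is just a placeholder):

```python
import torch

x = torch.randn(4, 8)
y = torch.relu(x) ** 2  # relu, then elementwise square
assert torch.allclose(y, torch.clamp(x, min=0) ** 2)
```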