Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "6c81ce329e0802883b22eabab0d3fa48357337ef1ecb45443828bf1f6254833f":
# ref: https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B
res = "exaone-moe"
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
res = "qwen35"

if res is None:
logger.warning("\n")
Expand Down Expand Up @@ -4359,7 +4362,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_mask_token_id(mask_token_id)


@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -4405,6 +4408,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if name.startswith("model.language_model.") or name.startswith("lm_head."):
return

# Skip MTP tensors
if name.startswith("mtp."):
return

if name.startswith("model.visual."):
name = name.replace("model.visual.", "visual.", 1)

Expand Down Expand Up @@ -4538,6 +4545,53 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3_5ForConditionalGeneration")
Comment thread
CISC marked this conversation as resolved.
Outdated
class Qwen3_5TextModel(Qwen3NextModel):
model_arch = gguf.MODEL_ARCH.QWEN35

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors - they go in the mmproj file
if name.startswith("model.visual."):
return

yield from super().modify_tensors(data_torch, name, bid)
Comment thread
CISC marked this conversation as resolved.
Outdated


@ModelBase.register("Qwen3_5MoeForConditionalGeneration")
class Qwen3_5MoeTextModel(Qwen3_5TextModel):
Comment thread
CISC marked this conversation as resolved.
Outdated
Comment thread
JJJYmmm marked this conversation as resolved.
Outdated
model_arch = gguf.MODEL_ARCH.QWEN35MOE

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
name = name.replace("language_model.", "")

if name.startswith("mtp."):
return

# NOTE: Qwen3.5MOE has native 3d experts FFN format, so no need to permute
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
mapped = f"{name}.weight" if not name.endswith(".weight") else name
# Input: (n_expert=128, n_embd=2048, n_ff_exp=768)
# Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
yield (self.map_tensor_name(mapped), data_torch)
return

if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
if data_torch.ndim < 3 or data_torch.shape[-2] % 2 != 0:
raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
split_dim = data_torch.shape[-2] // 2
gate = data_torch[..., :split_dim, :].contiguous()
up = data_torch[..., split_dim:, :].contiguous()
base_name = name.removesuffix(".weight")
base = base_name.rsplit('.', 1)[0]
mapped_gate = f"{base}.gate_proj.weight"
mapped_up = f"{base}.up_proj.weight"
yield (self.map_tensor_name(mapped_gate), gate)
yield (self.map_tensor_name(mapped_up), up)
return

yield from super().modify_tensors(data_torch, name, bid)
Comment thread
JJJYmmm marked this conversation as resolved.
Outdated


@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
model_arch = gguf.MODEL_ARCH.GPT2
Expand Down
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }
]

# some models are known to be broken upstream, so we will skip them as exceptions
Expand Down
66 changes: 64 additions & 2 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ class MODEL_ARCH(IntEnum):
QWEN3NEXT = auto()
QWEN3VL = auto()
QWEN3VLMOE = auto()
QWEN35 = auto()
QWEN35MOE = auto()
PHI2 = auto()
PHI3 = auto()
PHIMOE = auto()
Expand Down Expand Up @@ -557,13 +559,14 @@ class MODEL_TENSOR(IntEnum):
SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto()
SSM_ALPHA = auto() # qwen3.5
SSM_BETA_ALPHA = auto() # qwen3next
SSM_CONV1D_Q = auto() # Kimi Linear
SSM_CONV1D_K = auto() # Kimi Linear
SSM_CONV1D_V = auto() # Kimi Linear
SSM_F_A = auto() # Kimi Linear
SSM_F_B = auto() # Kimi Linear
SSM_BETA = auto() # Kimi Linear
SSM_BETA = auto() # Kimi Linear qwen3.5
SSM_G_A = auto() # Kimi Linear
SSM_G_B = auto() # Kimi Linear
TIME_MIX_W0 = auto()
Expand Down Expand Up @@ -814,6 +817,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.QWEN3NEXT: "qwen3next",
MODEL_ARCH.QWEN3VL: "qwen3vl",
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
MODEL_ARCH.QWEN35: "qwen35",
MODEL_ARCH.QWEN35MOE: "qwen35moe",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PHIMOE: "phimoe",
Expand Down Expand Up @@ -985,13 +990,14 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_ALPHA: "blk.{bid}.ssm_alpha", # qwen3.5
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", # Kimi Linear
MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", # Kimi Linear
MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", # Kimi Linear
MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", # Kimi Linear
MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", # Kimi Linear
MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear
MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear qwen3.5
MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", # Kimi Linear
MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", # Kimi Linear
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
Expand Down Expand Up @@ -1818,6 +1824,62 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.QWEN35: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.FFN_NORM,
Comment thread
JJJYmmm marked this conversation as resolved.
Outdated
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
],
MODEL_ARCH.QWEN35MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
],
MODEL_ARCH.PLAMO: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
15 changes: 14 additions & 1 deletion gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class TensorNameMap:
"transformer_encoder.{bid}.qkv", # neobert
"layers.{bid}.attn.Wqkv", # modern-bert
"model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm
"model.layers.{bid}.linear_attn.in_proj_qkv", # qwen3.5
),

# Attention query
Expand Down Expand Up @@ -359,6 +360,7 @@ class TensorNameMap:

MODEL_TENSOR.ATTN_GATE: (
"model.layers.{bid}.self_attn.gate_proj", # afmoe
"model.layers.{bid}.linear_attn.in_proj_z", # qwen3.5
"model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate
),

Expand Down Expand Up @@ -823,6 +825,10 @@ class TensorNameMap:
"model.layers.layers.{bid}.mixer.out_proj", # plamo2
),

MODEL_TENSOR.SSM_ALPHA: (
"model.layers.{bid}.linear_attn.in_proj_a", # qwen3.5
),

MODEL_TENSOR.SSM_BETA_ALPHA: (
"model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
),
Expand All @@ -844,7 +850,8 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.f_b_proj",
),
MODEL_TENSOR.SSM_BETA: (
"model.layers.{bid}.self_attn.b_proj",
"model.layers.{bid}.linear_attn.in_proj_b", # qwen3.5
"model.layers.{bid}.self_attn.b_proj", # Kimi Linear
),
MODEL_TENSOR.SSM_G_A: (
"model.layers.{bid}.self_attn.g_a_proj",
Expand Down Expand Up @@ -1872,6 +1879,12 @@ class TensorNameMap:
"model.layers.{bid}.post_attention_layernorm",
),
},
MODEL_ARCH.QWEN35: {
MODEL_TENSOR.FFN_NORM: (),
},
MODEL_ARCH.QWEN35MOE: {
MODEL_TENSOR.FFN_NORM: (),
},
Comment thread
JJJYmmm marked this conversation as resolved.
Outdated
}

mapping: dict[str, tuple[MODEL_TENSOR, str]]
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ add_library(llama
models/qwen3vl-moe.cpp
models/qwen3moe.cpp
models/qwen3next.cpp
models/qwen35.cpp
models/qwen35moe.cpp
models/refact.cpp
models/rnd1.cpp
models/rwkv6-base.cpp
Expand Down
64 changes: 63 additions & 1 deletion src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_QWEN35, "qwen35" },
{ LLM_ARCH_QWEN35MOE, "qwen35moe" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
Expand Down Expand Up @@ -366,6 +368,7 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
{ LLM_TENSOR_SSM_ALPHA, "blk.%d.ssm_alpha" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
Expand Down Expand Up @@ -968,7 +971,6 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
Expand All @@ -985,6 +987,63 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
};
case LLM_ARCH_QWEN35:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_SSM_A_NOSCAN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_BETA,
LLM_TENSOR_SSM_ALPHA,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
};
case LLM_ARCH_QWEN35MOE:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
LLM_TENSOR_SSM_A_NOSCAN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_BETA,
LLM_TENSOR_SSM_ALPHA,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
};
case LLM_ARCH_QWEN3VL:
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_HUNYUAN_DENSE:
Expand Down Expand Up @@ -2456,6 +2515,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
Expand Down Expand Up @@ -2675,6 +2735,8 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_NEMOTRON_H_MOE:
case LLM_ARCH_QWEN3NEXT:
case LLM_ARCH_KIMI_LINEAR:
case LLM_ARCH_QWEN35:
case LLM_ARCH_QWEN35MOE:
return true;
default:
return false;
Expand Down
5 changes: 4 additions & 1 deletion src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ enum llm_arch {
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_QWEN35,
LLM_ARCH_QWEN35MOE,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
Expand Down Expand Up @@ -404,13 +406,14 @@ enum llm_tensor {
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
LLM_TENSOR_SSM_ALPHA, // qwen3.5
// Kimi Linear KDA (using SSM_ prefix for consistency)
LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight
LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight
LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight
LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A
LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B
LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient
LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient and qwen3.5
LLM_TENSOR_SSM_G_A, // kimi: output gate projection A
LLM_TENSOR_SSM_G_B, // kimi: output gate projection B
LLM_TENSOR_TIME_MIX_W0,
Expand Down
Loading
Loading