Merged
6 changes: 3 additions & 3 deletions common/common.h
@@ -116,9 +116,9 @@ struct gpt_params {
     float rope_freq_base = 0.0f; // RoPE base frequency
     float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
     float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
-    float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
-    float yarn_beta_fast = 32.0f; // YaRN low correction dim
-    float yarn_beta_slow = 1.0f; // YaRN high correction dim
+    float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
Owner: Why is this change required?
Collaborator (Author): This copies from mainline.
+    float yarn_beta_fast = -1.0f; // YaRN low correction dim
+    float yarn_beta_slow = -1.0f; // YaRN high correction dim
     int32_t yarn_orig_ctx = 0; // YaRN original context length
     float defrag_thold = -1.0f; // KV cache defragmentation threshold

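The change makes sense once you know the convention the author is copying: in mainline llama.cpp a negative value for these YaRN parameters means "not set on the command line", and context setup then falls back to whatever the converter stored in the GGUF metadata (or a built-in default). With the old defaults of 1.0f/32.0f/1.0f, the CLI values would always override the per-model values written by convert_hf_to_gguf.py. A minimal Python sketch of that fallback pattern (hypothetical names; the real resolution happens in the C++ context setup, not here):

# Minimal sketch of the "-1 means unset" fallback pattern (hypothetical
# names; illustrates the convention, not the actual C++ code path).
def resolve_yarn_params(cli: dict[str, float], model_kv: dict[str, float]) -> dict[str, float]:
    defaults = {"yarn_attn_factor": 1.0, "yarn_beta_fast": 32.0, "yarn_beta_slow": 1.0}
    resolved = {}
    for key, fallback in defaults.items():
        if cli[key] >= 0.0:        # user set it explicitly on the CLI
            resolved[key] = cli[key]
        elif key in model_kv:      # value the converter wrote into the GGUF
            resolved[key] = model_kv[key]
        else:                      # neither: historical default
            resolved[key] = fallback
    return resolved

# Example: CLI leaves everything at -1, model metadata supplies beta_fast.
print(resolve_yarn_params(
    {"yarn_attn_factor": -1.0, "yarn_beta_fast": -1.0, "yarn_beta_slow": -1.0},
    {"yarn_beta_fast": 8.0},
))  # -> {'yarn_attn_factor': 1.0, 'yarn_beta_fast': 8.0, 'yarn_beta_slow': 1.0}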
101 changes: 78 additions & 23 deletions convert_hf_to_gguf.py
@@ -555,6 +555,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         # NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script
         # or pull the latest version of the model from Huggingface
         # don't edit the hashes manually!
+        if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
+            # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
+            res = "grok-2"
        if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
            # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
            res = "llama-bpe"
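For reference, chkhsh is a behavioral fingerprint rather than a file hash: get_vocab_base_pre encodes a fixed probe string with the tokenizer and hashes the resulting token IDs, so any change in pre-tokenization behavior shows up as a new hash. A sketch of the scheme under that assumption (the probe text in the real script is much longer):

from hashlib import sha256

# Sketch: fingerprint a tokenizer by hashing the token IDs it produces
# for a fixed probe string. The real probe text is much longer.
def tokenizer_fingerprint(tokenizer, probe: str = "Hello 'world'! \u00e9\u00e8 123") -> str:
    token_ids = tokenizer.encode(probe)
    return sha256(str(token_ids).encode()).hexdigest()

# Usage (assumes transformers is installed and the repo is reachable):
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("alvarobartt/grok-2-tokenizer")
# print(tokenizer_fingerprint(tok))  # matches "66b8..." only with the real probe text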
@@ -1905,57 +1908,109 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return tensors


-@Model.register("GrokForCausalLM")
+@Model.register("GrokForCausalLM", "Grok1ForCausalLM")
 class GrokModel(Model):
     model_arch = gguf.MODEL_ARCH.GROK

     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / 'tokenizer.model').is_file():
+            self._set_vocab_sentencepiece()
+            return
+
+        if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
+            logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
+            sys.exit(1)
+
+        self._set_vocab_gpt2()

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()

-    _experts: list[dict[str, Tensor]] | None = None
+        self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
+        self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
+        if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+        # Treat "original" as "yarn", seems to have been a mistake
+        if self.hparams.get("rope_type") in ("yarn", "original"):
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
+
+        if temp_len := self.hparams.get("attn_temperature_len"):
+            self.gguf_writer.add_attn_temperature_length(temp_len)
+
+        self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
+        self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
+        self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
+
+    _experts: list[dict[str, list[Tensor]]] | None = None
+    _cur_expert = ""

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors: list[tuple[str, Tensor]] = []
+        is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
+
+        if not is_expert:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
         # process the experts separately
-        if name.find(".moe.") != -1:
+        if is_expert or self._cur_expert:
             n_experts = self.hparams["num_local_experts"]

             assert bid is not None

             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]

-            self._experts[bid][name] = data_torch
-
-            if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
+            # concatenate split tensors
+            if name in self._experts[bid]:
+                self._cur_expert = name
+                self._experts[bid][name].append(data_torch)
+                return []
+            elif is_expert:
+                self._cur_expert = name
+                self._experts[bid][name] = [data_torch]
+                return []
+            else:
+                self._cur_expert = ""

-                # merge the experts into a single 3d tensor
-                for wid in ["linear", "linear_1", "linear_v"]:
-                    datas: list[Tensor] = []
+            for bid in range(self.block_count):
+                if len(self._experts[bid]) >= n_experts * 3:
+                    # merge the experts into a single 3d tensor
+                    for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
+                        datas: list[Tensor] = []

-                    for xid in range(n_experts):
-                        ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
-                        datas.append(self._experts[bid][ename])
-                        del self._experts[bid][ename]
+                        for xid in range(n_experts):
+                            ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
+                            if ename not in self._experts[bid]:
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
+                            tensor_list = self._experts[bid][ename]
+                            datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
+                            del self._experts[bid][ename]

-                    data_torch = torch.stack(datas, dim=0)
+                        data_torch = torch.stack(datas, dim=0)

-                    merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
+                        merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"

-                    new_name = self.map_tensor_name(merged_name)
+                        new_name = self.map_tensor_name(merged_name)

-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                        yield (new_name, data_torch)

-        return [(self.map_tensor_name(name), data_torch)]
+        yield from tensors


@Model.register("DbrxForCausalLM")
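The reworked modify_tensors now accepts two checkpoint layouts: original Grok files (transformer.decoder_layer.{bid}.moe.{xid}.linear*.weight, one tensor per expert) and HF-style grok-2 shards (model.layers.{bid}.block_sparse_moe.experts.{xid}.w*.weight, where a single expert weight may arrive split across files and is re-joined with torch.cat along the dimension recorded in wid[2]). Everything is then stacked into one [n_expert, rows, cols] tensor per projection. A toy reproduction of that merge:

import torch

# Toy reproduction of the expert merge: per-expert (and possibly split)
# 2-D weights become one [n_expert, rows, cols] tensor per projection.
n_expert, rows, cols = 2, 4, 8

# grok-2 style: each expert's w1 arrives as two row-split shards
shards = {(xid, part): torch.randn(rows // 2, cols)
          for xid in range(n_expert) for part in (0, 1)}

datas = []
for xid in range(n_expert):
    tensor_list = [shards[(xid, 0)], shards[(xid, 1)]]
    # dim=0 for w1/w3 (row-split), dim=1 for w2, mirroring wid[2] above
    datas.append(torch.cat(tensor_list, dim=0) if len(tensor_list) > 1 else tensor_list[0])

merged = torch.stack(datas, dim=0)
print(merged.shape)  # torch.Size([2, 4, 8])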
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -99,6 +99,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2", },
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902", },
     {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890", },
+    {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
 ]


24 changes: 16 additions & 8 deletions gguf-py/gguf/constants.py
@@ -97,6 +97,7 @@ class LLM:
         DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
         ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
+        ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"

     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -112,16 +113,22 @@ class Attention:
         KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW = "{arch}.attention.sliding_window"
+        OUTPUT_SCALE = "{arch}.attention.output_scale"
+        TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"

     class Rope:
-        DIMENSION_COUNT = "{arch}.rope.dimension_count"
-        FREQ_BASE = "{arch}.rope.freq_base"
-        SCALING_TYPE = "{arch}.rope.scaling.type"
-        SCALING_FACTOR = "{arch}.rope.scaling.factor"
-        SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
-        SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
-        SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
-        SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
+        DIMENSION_COUNT          = "{arch}.rope.dimension_count"
+        FREQ_BASE                = "{arch}.rope.freq_base"
+        SCALING_TYPE             = "{arch}.rope.scaling.type"
+        SCALING_FACTOR           = "{arch}.rope.scaling.factor"
+        SCALING_ATTN_FACTOR      = "{arch}.rope.scaling.attn_factor"
+        SCALING_ORIG_CTX_LEN     = "{arch}.rope.scaling.original_context_length"
+        SCALING_FINETUNED        = "{arch}.rope.scaling.finetuned"
+        SCALING_YARN_LOG_MUL     = "{arch}.rope.scaling.yarn_log_multiplier"
+        SCALING_YARN_EXT_FACTOR  = "{arch}.rope.scaling.yarn_ext_factor"
+        SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
+        SCALING_YARN_BETA_FAST   = "{arch}.rope.scaling.yarn_beta_fast"
+        SCALING_YARN_BETA_SLOW   = "{arch}.rope.scaling.yarn_beta_slow"

class Split:
LLM_KV_SPLIT_NO = "split.no"
@@ -540,6 +547,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_POST_NORM,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
MODEL_ARCH.GPTNEOX: [
21 changes: 21 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -656,6 +656,9 @@ def add_logit_scale(self, value: float) -> None:
     def add_attn_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

+    def add_router_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_final_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
@@ -701,6 +704,12 @@ def add_relative_attn_buckets_count(self, value: int) -> None:
     def add_sliding_window(self, value: int) -> None:
         self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)

+    def add_attn_output_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
+
+    def add_attn_temperature_length(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

@@ -728,6 +737,18 @@ def add_rope_scaling_finetuned(self, value: bool) -> None:
     def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)

+    def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
+
     def add_ssm_conv_kernel(self, value: int) -> None:
         self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

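Taken together with the new Keys.Rope entries in constants.py, these methods emit keys such as grok.rope.scaling.yarn_beta_fast. A hedged usage sketch (file name and values are illustrative; the real calls are made by GrokModel.set_gguf_parameters during conversion):

from gguf import GGUFWriter

# Illustrative only: write the new YaRN and attention keys for a "grok"
# arch file. Values are made up; real ones come from the model's config.
writer = GGUFWriter("grok-2-metadata-demo.gguf", "grok")
writer.add_rope_scaling_yarn_ext_factor(1.0)
writer.add_rope_scaling_yarn_attn_factor(1.0)
writer.add_rope_scaling_yarn_beta_fast(8.0)
writer.add_rope_scaling_yarn_beta_slow(1.0)
writer.add_attn_output_scale(0.08838834764831845)  # typically head_dim**-0.5
writer.add_attn_temperature_length(8192)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()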
4 changes: 4 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -24,6 +24,7 @@ class TensorNameMap:
         "backbone.embedding", # mamba
         "backbone.embeddings", # mamba-hf
         "transformer.in_out_embed", # Grok
+        "model.layers.{bid}.pre_attn_norm", # grok-2
         "embedding.word_embeddings", # chatglm
         "transformer.token_embeddings", # openelm
         "shared", # t5
@@ -202,6 +203,7 @@ class TensorNameMap:
         "encoder.layer.{bid}.attention.output.LayerNorm", # bert
         "encoder.layers.{bid}.norm1", # nomic-bert
         "transformer.decoder_layer.{bid}.rms_norm_1", # Grok
+        "model.layers.{bid}.post_attn_norm", # grok-2
         "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
     ),

@@ -230,6 +232,7 @@ class TensorNameMap:
         "h.{bid}.ln_2", # gpt2
         "model.layers.{bid}.ffn_norm", # internlm2
         "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
+        "model.layers.{bid}.pre_moe_norm", # grok-2
         "encoder.layers.{bid}.post_attention_layernorm", # chatglm
         "transformer.layers.{bid}.ffn_norm", # openelm
     ),
@@ -242,6 +245,7 @@ class TensorNameMap:
     # Post feed-forward norm
     MODEL_TENSOR.FFN_POST_NORM: (
         "model.layers.{bid}.post_feedforward_layernorm", # gemma2
+        "model.layers.{bid}.post_moe_norm", # grok-2
     ),

MODEL_TENSOR.FFN_GATE_INP: (
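These four entries are what let map_tensor_name translate grok-2's HF-style norm names onto the existing Grok GGUF tensor slots. A sketch of the lookup as gguf-py exposes it (assumes a gguf-py build that includes the mappings above; the printed name is the expected FFN_POST_NORM slot):

import gguf

# Hedged sketch: resolve a grok-2 HF tensor name to its GGUF name using
# the mappings added above (requires a gguf-py that includes them).
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.GROK, 64)
name = tmap.get_name("model.layers.0.post_moe_norm.weight", try_suffixes=(".weight", ".bias"))
print(name)  # expected: "blk.0.post_ffw_norm.weight" once the mapping is in place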
9 changes: 8 additions & 1 deletion src/llama-arch.h
@@ -99,6 +99,7 @@ enum llm_kv {
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
     LLM_KV_SWIN_NORM,
     LLM_KV_RESCALE_EVERY_N_LAYERS,
@@ -123,7 +124,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
-
+    LLM_KV_ATTENTION_OUTPUT_SCALE,
+    LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
     LLM_KV_ROPE_SCALE_LINEAR,
@@ -134,6 +136,11 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_FINETUNED,
     LLM_KV_ROPE_SCALING_YARN_LOG_MUL,

+    LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
+    LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,
+
LLM_KV_SPLIT_NO,
LLM_KV_SPLIT_COUNT,
LLM_KV_SPLIT_TENSORS_COUNT,
12 changes: 12 additions & 0 deletions src/llama-vocab.cpp
@@ -433,6 +433,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_GROK_2:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -1973,6 +1980,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "kimi-k2") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
                 clean_spaces = false;
+            }
+            else if (
+                tokenizer_pre == "grok-2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
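Since llama.cpp's built-in unicode regex support has no (?i:...) inline-flag group, the case-insensitive contraction alternation from tokenizer.json is expanded into explicit [sS]-style character classes. A quick spot check of the rewrite with Python's third-party regex module (which, unlike re, understands \p{L} and \p{N}):

import regex  # third-party module: pip install regex

original = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
expanded = r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

# Spot check on mixed-case contractions, digits, and accented letters.
sample = "We'LL tokenize DON'T 123 déjà-vu!\n  new line"
assert regex.findall(original, sample) == regex.findall(expanded, sample)
print(regex.findall(expanded, sample))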
1 change: 1 addition & 0 deletions src/llama-vocab.h
@@ -47,6 +47,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
     LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
+    LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
 };

struct LLM_KV;