83 changes: 77 additions & 6 deletions convert_hf_to_gguf.py
@@ -736,9 +736,10 @@ def __init__(self, *args, **kwargs):
else:
self.hf_arch = ""

if "text_config" in self.hparams:
llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config"
if llm_config_key in self.hparams:
# move the text_config to the root level
self.hparams = {**self.hparams, **self.hparams["text_config"]}
self.hparams = {**self.hparams, **self.hparams[llm_config_key]}

self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
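
Note: most HF multimodal checkpoints nest the language-model settings under text_config, while GLM-ASR uses lm_config; the converter now flattens whichever key is present into the root-level hparams so lookups such as find_hparam keep working. A minimal sketch of the behavior (the config values here are made up, not taken from a real checkpoint):

# Sketch of the flattening step above; hparams values are illustrative.
hparams = {
    "model_type": "glmasr",
    "lm_config": {"num_hidden_layers": 28, "hidden_size": 4096},
}
llm_config_key = "lm_config" if "lm_config" in hparams else "text_config"
if llm_config_key in hparams:
    hparams = {**hparams, **hparams[llm_config_key]}
assert hparams["num_hidden_layers"] == 28  # now visible at the root level
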
@@ -1604,7 +1605,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]

n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]

has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@@ -1621,11 +1622,12 @@ def __init__(self, *args, **kwargs):

# get n_embd of the text model
if not self.is_mistral_format:
if "text_config" not in self.hparams:
llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config"
if llm_config_key not in self.hparams:
self.hparams["text_config"] = {}
if "audio_config" not in self.hparams:
self.hparams["audio_config"] = {}
text_config = {**self.hparams, **self.hparams["text_config"]}
text_config = {**self.hparams, **self.hparams[llm_config_key]}
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
else:
text_config = {
@@ -1680,7 +1682,8 @@ def get_vision_config(self) -> dict[str, Any] | None:
return self.global_config.get(config_name)

def get_audio_config(self) -> dict[str, Any] | None:
-        return self.global_config.get("audio_config")
+        mm_config_key = "whisper_config" if "whisper_config" in self.hparams else "audio_config"
+        return self.global_config.get(mm_config_key)

def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
@@ -2356,6 +2359,7 @@ def prepare_tensors(self):
"VLlama3ForCausalLM",
"LlavaForConditionalGeneration",
"VoxtralForConditionalGeneration",
"GlmasrModel",
"LlamaModel")
class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA
@@ -2407,6 +2411,16 @@ def set_vocab(self):
# Apply to granite small models only
if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False)
        if isinstance(self.hparams.get("eos_token_id"), list):
            # multiple terminators declared: resolve the special tokens from the tokenizer's added vocab
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
            special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
            special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
            special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
            special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
            special_vocab.chat_template = "glmedge"  # must be set before add_to_gguf, or it is never written
            special_vocab.add_to_gguf(self.gguf_writer)
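
Background on the branch above: GLM-style configs declare eos_token_id as a list of several terminator ids, which the single-token EOS field in GGUF cannot hold, so the converter resolves each special token from the tokenizer's added vocabulary instead. A hedged sketch of that lookup (the model path and token ids are placeholders, not real GLM-ASR values):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/path/to/glm-asr", trust_remote_code=True)
added = tokenizer.get_added_vocab()  # e.g. {"<|endoftext|>": 59246, "<|user|>": 59253, ...} (made-up ids)
eos_id = added["<|endoftext|>"]      # written to GGUF as tokenizer.ggml.eos_token_id
eot_id = added["<|user|>"]           # end-of-turn marker consumed by the chat template
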

def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -2443,6 +2457,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
"vision_language_adapter.",
"patch_merger.",
"pre_mm_projector_norm",
"audio_encoder.",
]

is_multimodal_tensor = "vision_tower" in name \
@@ -8999,6 +9014,62 @@ def __init__(self, *args, **kwargs):
raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument")


@ModelBase.register("GlmasrModel")
class GlmASRWhisperEncoderModel(MmprojModel):
    has_vision_encoder = False
    has_audio_encoder = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
            # map whisper-style config keys onto the generic names MmprojModel expects
            self.hparams["hidden_size"] = self.hparams["d_model"]
            self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
            self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLMA)
        self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
        self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))

    def tensor_force_quant(self, name, new_name, bid, n_dims):
        if ".conv" in name and ".weight" in name:
            # keep the conv kernels in F16 instead of quantizing them
            return gguf.GGMLQuantizationType.F16
        return super().tensor_force_quant(name, new_name, bid, n_dims)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        if name.startswith("model.") or name.startswith("lm_head."):
            # skip language model tensors
            return []

        if name.startswith("audio_encoder.whisper."):
            name = name.replace("audio_encoder.whisper.", "audio_tower.")
        if "audio_encoder.layer_norm." in name or "audio_encoder.proj." in name:
            name = name.replace("audio_encoder.", "audio_encoder.adapting.")

        if name.startswith("audio_encoder.audio_bos_eos_token."):
            # one tensor holds both markers: row 0 is BOI, row 1 is EOI
            return [(self.map_tensor_name("model.vision.boi"), data_torch[0]), (self.map_tensor_name("model.vision.eoi"), data_torch[1])]

        if name.startswith("audio_encoder.adapting."):
            name = name.replace("audio_encoder.adapting.", "audio.multi_modal_projector.")
            if ".layer_norm." in name:
                name = name.replace(".layer_norm.", ".ln_pre.")
            if ".0." in name:
                name = name.replace(".0.", ".linear_1.")
            if ".2." in name:
                name = name.replace(".2.", ".linear_2.")
            if ".proj." in name:
                return []

        if "conv1.bias" in name or "conv2.bias" in name:
            # give the conv1/conv2 bias a trailing unit dimension for ggml's conv1d
            data_torch = data_torch.unsqueeze(-1)

        return [(self.map_tensor_name(name), data_torch)]
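
To make the renames above concrete, here is a trace of the chain for one adaptor tensor; the source name follows the patterns handled in this method, and the final GGUF name comes from gguf-py's tensor map (requested on the C++ side via TN_MM_AUDIO_MLP):

# Trace of modify_tensors' renaming for the adaptor's first linear layer.
name = "audio_encoder.adapting.0.weight"
if name.startswith("audio_encoder.adapting."):
    name = name.replace("audio_encoder.adapting.", "audio.multi_modal_projector.")
    if ".0." in name:
        name = name.replace(".0.", ".linear_1.")
print(name)  # audio.multi_modal_projector.linear_1.weight, then handed to map_tensor_name()
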


@ModelBase.register("Qwen2AudioForConditionalGeneration")
class WhisperEncoderModel(MmprojModel):
has_vision_encoder = False # no vision encoder
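
With these changes the GLM-ASR checkpoint should convert in two passes like the other audio models here: once with the converter's --mmproj flag to export the whisper encoder plus adaptor (the path the Ultravox error message above points at), and once without it for the language model; e.g. python convert_hf_to_gguf.py /path/to/GlmasrModel --mmproj would be the expected shape of the first invocation, though exact flags may vary by revision.
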
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -3320,6 +3320,7 @@ class VisionProjectorType:
ULTRAVOX = "ultravox"
INTERNVL = "internvl"
QWEN2A = "qwen2a" # audio
GLMA = "glma" # audio
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
LFM2 = "lfm2"
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
@@ -149,6 +149,7 @@ enum projector_type {
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_QWEN2A,
PROJECTOR_TYPE_GLMA,
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_LFM2,
@@ -175,6 +176,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
{ PROJECTOR_TYPE_GLMA, "glma"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
44 changes: 43 additions & 1 deletion tools/mtmd/clip.cpp
@@ -388,6 +388,7 @@ struct clip_model {
ggml_tensor * conv1d_2_w = nullptr;
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_pre_b = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;

// cogvlm
@@ -1829,7 +1830,6 @@ struct clip_graph {
GGML_ASSERT(model.layers[0].q_b);
GGML_ASSERT(model.layers[0].v_b);
GGML_ASSERT(!model.layers[0].k_b); // no bias for k
-    GGML_ASSERT(model.post_ln_w && model.post_ln_b);

ggml_tensor * pos_embd_selected = ggml_view_2d(
ctx0, model.position_embeddings,
@@ -1891,6 +1891,18 @@
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);

} else if (ctx->proj_type() == PROJECTOR_TYPE_GLMA) {
cur = ggml_norm(ctx0, cur, hparams.eps);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
cur = ggml_add(ctx0, cur, model.mm_norm_pre_b);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * 4, cur->ne[1] / 4);
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
cur = ggml_add(ctx0, cur, model.mm_1_b);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
cur = ggml_add(ctx0, cur, model.mm_2_b);
cur = ggml_concat(ctx0, model.mm_boi, cur, 1);
cur = ggml_concat(ctx0, cur, model.mm_eoi, 1);
} else {
GGML_ABORT("%s: unknown projector type", __func__);
}
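
The new GLMA branch mirrors GLM-ASR's adaptor: a LayerNorm, a reshape that folds 4 consecutive frames into the channel dimension (trading sequence length for width), a two-layer MLP with GELU, and learned BOI/EOI embeddings concatenated around the sequence. A shape sketch under assumed sizes (whisper d_model 1280 and a 4096-dim text embedding are illustrative, not confirmed for GLM-ASR):

# Shape walk-through of the GLMA projector; sizes are assumptions.
n_embd, n_frames = 1280, 1500    # whisper encoder output: (d_model, frames)
x = (n_embd, n_frames)           # LayerNorm: shape unchanged
x = (n_embd * 4, n_frames // 4)  # reshape: 4x fewer, 4x wider frames -> (5120, 375)
x = (4096, n_frames // 4)        # linear_1 + GELU + linear_2 -> text embedding size
x = (4096, n_frames // 4 + 2)    # concat BOI in front and EOI at the end
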
@@ -2518,6 +2530,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_GLMA:
{
res = graph.build_whisper_enc();
} break;
@@ -3225,6 +3238,21 @@ struct clip_model_loader {
model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
} break;
case PROJECTOR_TYPE_GLMA:
{
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias"));
model.mm_boi = get_tensor(string_format(TN_TOK_BOI, "weight"));
model.mm_eoi = get_tensor(string_format(TN_TOK_EOI, "weight"));
} break;
case PROJECTOR_TYPE_LLAMA4:
{
model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
@@ -4606,6 +4634,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
n_patches /= 2;
}
} break;
case PROJECTOR_TYPE_GLMA:
{
n_patches = img->nx;
// whisper halves the number of input frames after its conv1d stage
n_patches /= 2;
// reshape by merge_factor
n_patches /= 4;
// for BOI and EOI token embeddings
n_patches += 2;
} break;
case PROJECTOR_TYPE_COGVLM:
{
n_patches += 2; // for BOI and EOI token embeddings
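
A quick check of the GLMA arithmetic above, assuming whisper's usual 30-second window at 100 mel frames per second (an assumption, not stated in this patch): 3000 frames halve to 1500 after the strided conv, drop to 375 after the merge-by-4 reshape, and end at 377 once BOI and EOI are added.

assert (3000 // 2) // 4 + 2 == 377  # mirrors the GLMA case in clip_n_output_tokens
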
@@ -4941,6 +4979,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_VOXTRAL:
@@ -5051,6 +5090,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_model_proj->ne[1];
case PROJECTOR_TYPE_QWEN2A:
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_GLMA:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
return ctx->model.mm_2_w->ne[1];
@@ -5097,6 +5138,7 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
|| ctx->proj_type() == PROJECTOR_TYPE_GLMA
|| ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL;
}
