42 changes: 40 additions & 2 deletions convert_hf_to_gguf.py
@@ -117,7 +117,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
disable_mistral_community_chat_template: bool = False,
sentence_transformers_dense_modules: bool = False,
fuse_gate_up_exps: bool = False):
fuse_gate_up_exps: bool = False,
fuse_qkv: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@@ -139,6 +140,10 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.fuse_gate_up_exps = fuse_gate_up_exps
self._gate_exp_buffer: dict[int, Tensor] = {}
self._up_exp_buffer: dict[int, Tensor] = {}
self.fuse_qkv = fuse_qkv
self._q_buffer: dict[int, Tensor] = {}
self._k_buffer: dict[int, Tensor] = {}
self._v_buffer: dict[int, Tensor] = {}
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
self.metadata_override = metadata_override
@@ -551,6 +556,33 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
return []

# Handle Q/K/V tensor fusion if enabled
if self.fuse_qkv and bid is not None:
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_Q, bid):
self._q_buffer[bid] = data_torch
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_K, bid):
self._k_buffer[bid] = data_torch
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_V, bid):
self._v_buffer[bid] = data_torch

# Check if all three Q, K, V are buffered for this layer
if bid in self._q_buffer and bid in self._k_buffer and bid in self._v_buffer:
q_data = self._q_buffer.pop(bid)
k_data = self._k_buffer.pop(bid)
v_data = self._v_buffer.pop(bid)
# Q shape: (n_embd_q, n_embd), K shape: (n_embd_k, n_embd), V shape: (n_embd_v, n_embd)
# concatenate to (n_embd_q + n_embd_k + n_embd_v, n_embd)
fused_data = torch.cat([q_data, k_data, v_data], dim=0)
fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_QKV, bid)
logger.info(f"Fused Q, K, V into QKV for layer {bid}")
return [(fused_name, fused_data)]

# If we buffered a Q/K/V tensor, wait for the others
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_Q, bid) or \
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_K, bid) or \
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_V, bid):
return []

return [(new_name, data_torch)]

def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
@@ -12293,6 +12325,11 @@ def parse_args() -> argparse.Namespace:
help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.",
)

parser.add_argument(
"--fuse-qkv", action="store_true",
help="Fuse separate Q, K, V weight tensors into a single QKV tensor.",
)

args = parser.parse_args()
if not args.print_supported_models and args.model is None:
parser.error("the following arguments are required: model")
@@ -12431,7 +12468,8 @@ def main() -> None:
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
fuse_gate_up_exps=args.fuse_gate_up_exps
fuse_gate_up_exps=args.fuse_gate_up_exps,
fuse_qkv=args.fuse_qkv
)

if args.vocab_only:
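For reference, the concatenation that modify_tensors() performs with --fuse-qkv can be checked in isolation. This is a minimal sketch; the dimensions (a Llama-style layout with 32 query heads, 8 KV heads, head size 128) are illustrative and not taken from this PR:

import torch

# illustrative dimensions, not from the PR: n_embd=4096, 32 Q heads, 8 KV heads, head_dim=128
n_embd, n_head, n_head_kv, head_dim = 4096, 32, 8, 128

q = torch.zeros(n_head * head_dim, n_embd)     # (4096, 4096)
k = torch.zeros(n_head_kv * head_dim, n_embd)  # (1024, 4096)
v = torch.zeros(n_head_kv * head_dim, n_embd)  # (1024, 4096)

# same concatenation as in modify_tensors(): rows are stacked along dim 0
qkv = torch.cat([q, k, v], dim=0)
assert qkv.shape == (n_head * head_dim + 2 * n_head_kv * head_dim, n_embd)  # (6144, 4096)

The fused tensor keeps n_embd as its input dimension; only the output rows grow.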
3 changes: 3 additions & 0 deletions gguf-py/gguf/constants.py
@@ -1365,6 +1365,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
@@ -1702,6 +1703,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
@@ -1780,6 +1782,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
166 changes: 166 additions & 0 deletions scripts/fuse_qkv_gguf.py
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""Fuse Q/K/V tensors in an existing GGUF file into a single QKV tensor.

This script operates at the binary level to preserve ALL metadata (including
tokenizer) byte-for-byte from the original file.

Usage:
python scripts/fuse_qkv_gguf.py input.gguf output.gguf
"""
import sys, struct, os, re
import numpy as np

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'gguf-py'))
from gguf import GGUFReader


def align_offset(offset, alignment=32):
return (offset + alignment - 1) // alignment * alignment


def write_tensor_info(f, name, n_dims, dims, tensor_type, data_offset):
"""Write one tensor info entry in GGUF format."""
name_bytes = name.encode('utf-8')
f.write(struct.pack('<Q', len(name_bytes)))
f.write(name_bytes)
f.write(struct.pack('<I', n_dims))
for d in dims:
f.write(struct.pack('<Q', d))
f.write(struct.pack('<I', tensor_type))
f.write(struct.pack('<Q', data_offset))


def main():
if len(sys.argv) < 3:
print(f"Usage: {sys.argv[0]} input.gguf output.gguf")
sys.exit(1)

input_path = sys.argv[1]
output_path = sys.argv[2]

print(f"Reading {input_path}...")
reader = GGUFReader(input_path)

with open(input_path, 'rb') as f:
magic = f.read(4)
version = struct.unpack('<I', f.read(4))[0]
n_tensors_orig = struct.unpack('<Q', f.read(8))[0]
n_kv = struct.unpack('<Q', f.read(8))[0]

print(f" Version: {version}, Tensors: {n_tensors_orig}, KV fields: {n_kv}")

ti_start = min(t.field.offset for t in reader.tensors)
kv_data_start = 24
kv_data_end = ti_start

print(f" KV data: {kv_data_start} - {kv_data_end} ({kv_data_end - kv_data_start} bytes)")

tensor_map = {t.name: t for t in reader.tensors}

qkv_pattern = re.compile(r'^blk\.(\d+)\.attn_([qkv])\.weight$')
layer_qkv = {}
fused_names = set()

for name, t in tensor_map.items():
m = qkv_pattern.match(name)
if m:
layer = int(m.group(1))
qkv_type = m.group(2)
if layer not in layer_qkv:
layer_qkv[layer] = {}
layer_qkv[layer][qkv_type] = t
fused_names.add(name)

fuse_layers = sorted([l for l, d in layer_qkv.items() if len(d) == 3])
print(f" Fusing {len(fuse_layers)} layers: {fuse_layers[0]}-{fuse_layers[-1]}")

output_tensors = []
seen_layers = set()

for t in reader.tensors:
if t.name in fused_names:
m = qkv_pattern.match(t.name)
layer = int(m.group(1))
if layer in seen_layers:
continue
seen_layers.add(layer)

q = layer_qkv[layer]['q']
k = layer_qkv[layer]['k']
v = layer_qkv[layer]['v']

assert q.tensor_type == k.tensor_type == v.tensor_type

q_dims = [int(x) for x in q.field.parts[3]]
k_dims = [int(x) for x in k.field.parts[3]]
v_dims = [int(x) for x in v.field.parts[3]]

assert q_dims[0] == k_dims[0] == v_dims[0]

fused_ne0 = q_dims[0]
fused_ne1 = q_dims[1] + k_dims[1] + v_dims[1]
fused_name = f"blk.{layer}.attn_qkv.weight"

fused_data = np.concatenate([
np.array(q.data, copy=True),
np.array(k.data, copy=True),
np.array(v.data, copy=True),
])

print(f" Layer {layer}: Q{q_dims}+K{k_dims}+V{v_dims} -> QKV[{fused_ne0},{fused_ne1}] {fused_data.nbytes} bytes")

output_tensors.append((fused_name, 2, [fused_ne0, fused_ne1],
int(q.tensor_type), fused_data.tobytes()))
else:
dims = [int(x) for x in t.field.parts[3]]
n_dims = int(t.field.parts[2][0])
output_tensors.append((t.name, n_dims, dims,
int(t.tensor_type), bytes(t.data)))

n_tensors_new = len(output_tensors)
print(f"\n {n_tensors_orig} -> {n_tensors_new} tensors")

with open(input_path, 'rb') as f:
f.seek(kv_data_start)
kv_data_bytes = f.read(kv_data_end - kv_data_start)

print(f"\nWriting {output_path}...")
alignment = 32

with open(output_path, 'wb') as f:
f.write(magic)
f.write(struct.pack('<I', version))
f.write(struct.pack('<Q', n_tensors_new))
f.write(struct.pack('<Q', n_kv))
f.write(kv_data_bytes)

data_offsets = []
current_data_offset = 0
for name, n_dims, dims, ttype, data in output_tensors:
aligned = align_offset(current_data_offset, alignment)
data_offsets.append(aligned)
current_data_offset = aligned + len(data)

for i, (name, n_dims, dims, ttype, data) in enumerate(output_tensors):
write_tensor_info(f, name, n_dims, dims, ttype, data_offsets[i])

ti_section_end = f.tell()
tensor_data_start = align_offset(ti_section_end, alignment)
if tensor_data_start > ti_section_end:
f.write(b'\x00' * (tensor_data_start - ti_section_end))

for i, (name, n_dims, dims, ttype, data) in enumerate(output_tensors):
current_pos = f.tell() - tensor_data_start
target_pos = data_offsets[i]
if target_pos > current_pos:
f.write(b'\x00' * (target_pos - current_pos))
f.write(data)

final_size = f.tell()

print(f" Output size: {final_size / 1e9:.2f} GB")
print(" Done!")


if __name__ == '__main__':
main()
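One way to check a file produced by this script is to reopen it with gguf-py's GGUFReader and list the fused tensors; the output file name below is illustrative:

from gguf import GGUFReader

# illustrative check of a file produced by fuse_qkv_gguf.py
reader = GGUFReader("output.gguf")
fused = [t for t in reader.tensors if t.name.endswith(".attn_qkv.weight")]
print(f"found {len(fused)} fused QKV tensors")
for t in fused[:3]:
    print(t.name, [int(d) for d in t.shape], t.tensor_type.name)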
3 changes: 3 additions & 0 deletions src/llama-arch.cpp
@@ -552,6 +552,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ROPE_FREQS,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
@@ -759,6 +760,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
@@ -960,6 +962,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_OUTPUT,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K,
45 changes: 35 additions & 10 deletions src/llama-model.cpp
@@ -2708,6 +2708,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags);
}
};

// helper: try merged QKV first, fall back to separate Q, K, V
auto create_tensor_qkv = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_embd_head_k_, int64_t n_head_, int64_t n_embd_k_gqa_, int64_t n_embd_v_gqa_) {
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_head_k_ * n_head_ + n_embd_k_gqa_ + n_embd_v_gqa_}, TENSOR_NOT_REQUIRED);
if (layer.wqkv == nullptr) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_head_k_ * n_head_}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_gqa_}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_gqa_}, 0);
}
};
switch (arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT:
@@ -2733,9 +2743,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
// only LLaMA-family archs have fused QKV inference graph support
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_LLAMA_EMBED) {
create_tensor_qkv(layer, i, n_embd, n_embd_head_k, n_head, n_embd_k_gqa, n_embd_v_gqa);
} else {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

// optional bias tensors
@@ -3556,12 +3571,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
// only Qwen2 arch has fused QKV inference graph support
if (arch == LLM_ARCH_QWEN2) {
create_tensor_qkv(layer, i, n_embd, n_embd_head_k, n_head, n_embd_k_gqa, n_embd_v_gqa);
} else {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);


// optional bias tensors
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
@@ -3645,9 +3665,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
// only Qwen3 arch has fused QKV inference graph support
if (arch == LLM_ARCH_QWEN3) {
create_tensor_qkv(layer, i, n_embd, n_embd_head_k, n_head, n_embd_gqa, n_embd_gqa);
} else {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
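The create_tensor_qkv helper above requests the fused weight with ne dims {n_embd, n_embd_head_k*n_head + n_embd_k_gqa + n_embd_v_gqa}. That should correspond to the (rows, n_embd) torch shape written by the converter, since gguf-py stores tensor dimensions in reverse of the numpy/torch order. A minimal check, reusing the illustrative head counts from the earlier sketch and assuming n_embd_k_gqa == n_embd_v_gqa:

# illustrative only: the torch shape produced by the converter, reversed, matches the
# ne dims that create_tensor_qkv() requests
n_embd, n_head, n_head_kv, head_dim = 4096, 32, 8, 128

torch_shape = (n_head * head_dim + 2 * n_head_kv * head_dim, n_embd)  # from torch.cat(..., dim=0)
gguf_ne = (n_embd, n_head * head_dim + n_head_kv * head_dim + n_head_kv * head_dim)

assert tuple(reversed(torch_shape)) == gguf_ne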