Draft

Changes from all commits (33 commits)
a337ebd
model : Initial support for DeepseekV32ForCausalLM (for now with dens…
sszymczy Mar 11, 2026
e467684
model : added indexer q and k calculation in DeepseekV32ForCausalLM.
sszymczy Mar 12, 2026
723f0ce
ggml : add Hadamard transform GGML OP and implementation
sszymczy Mar 12, 2026
72b7214
kv-cache : add cache for indexer keys (temporary solution)
sszymczy Mar 13, 2026
961bc95
convert : DSA indexer weights are bf16 in the original fp8 model, so …
sszymczy Mar 14, 2026
9a63e7a
model : crude proof-of-concept implementation of the DSA indexer for …
sszymczy Mar 14, 2026
3eb340e
ggml : add CUDA Hadamard transformation implementation (borrowed from…
sszymczy Mar 15, 2026
08dc7fd
ggml : add new GGML_OP_WHERE_ID (akin to torch where but using indices)
sszymczy Mar 15, 2026
998f496
model : used new GGML_OP_WHERE_ID op in DeepSeek V3.2 lightning index…
sszymczy Mar 15, 2026
6c9d773
model : handle multiple streams in DeepSeek V3.2 lightning indexer
sszymczy Mar 16, 2026
cb94b56
ggml : handle multiple streams in CUDA GGML_OP_WHERE_ID implementation
sszymczy Mar 16, 2026
02c2159
kv-cache : fix crashes for models without indexer
sszymczy Mar 16, 2026
e7aa89a
model : replaced ggml_argsort_top_k with ggml_top_k in DeepSeek V3.2 …
sszymczy Mar 22, 2026
1874ac9
model : added comments in DeepSeek V3.2 lightning indexer implementat…
sszymczy Mar 23, 2026
4309c84
kv-cache : added llama_kv_cache_dsa KV cache specific to DSA composed…
sszymczy Mar 24, 2026
9b0a4ee
ggml : replaced GGML_OP_WHERE_ID with GGML_OP_SCATTER that works simi…
sszymczy Mar 24, 2026
0ee5d80
ggml : added inplace version of GGML_OP_SCATTER and tests for this OP
sszymczy Mar 24, 2026
7f5578f
gguf-py : removed obsolete KV_B tensor from DEEPSEEK32 arch
sszymczy Mar 24, 2026
54945c7
convert : make pyright happy
sszymczy Mar 24, 2026
5677f08
ggml : added f16 version of GGML_OP_SCATTER
sszymczy Mar 25, 2026
1c830a1
ggml : added f16 version of GGML_OP_FILL
sszymczy Mar 25, 2026
83a0313
model : GGML_OP_SCATTER AND GGML_OP_FILL now work with f16 data, so w…
sszymczy Mar 25, 2026
6011bdd
ggml : fix bug in CUDA Hadamard transform implementation
sszymczy Mar 27, 2026
4aec6a8
ggml : simplified testing for nh being power of 2 in Hadamard transfo…
sszymczy Mar 27, 2026
a74d83a
ggml : added test for GGML_OP_HADAMARD
sszymczy Mar 27, 2026
5b9ce6c
convert : check if add_bos_token is true when converting DeepseekV32F…
sszymczy Mar 27, 2026
57a8def
Merge remote-tracking branch 'upstream/master' into deepseek-dsa
sszymczy Mar 31, 2026
6959bcf
graph : replaced llama_ik_cache with llama_kv_cache instance created …
sszymczy Apr 1, 2026
f443d0c
graph : implemented llm_graph_input_attn_k_dsa
sszymczy Apr 1, 2026
d3236d8
graph : renamed DSA-related suffixes, since in DSA-related classes _b…
sszymczy Apr 1, 2026
346c2b4
Merge remote-tracking branch 'upstream/master' into deepseek-dsa
sszymczy Apr 2, 2026
5086217
llama : handle LLM_ARCH_DEEPSEEK32 in test-llama-archs
sszymczy Apr 2, 2026
a7820f6
model : replace ggml_hadamard() in DEEPSEEK32 with Hadamard rotation …
sszymczy Apr 2, 2026
143 changes: 143 additions & 0 deletions convert_hf_to_gguf.py
@@ -831,6 +831,8 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.SSM_CONV1D_Q,
gguf.MODEL_TENSOR.SSM_CONV1D_K,
gguf.MODEL_TENSOR.SSM_CONV1D_V,
# DSA indexer weights should be F32
gguf.MODEL_TENSOR.INDEXER_PROJ,
)
)
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
@@ -8737,6 +8739,147 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register(
"DeepseekV32ForCausalLM",
)
class DeepseekV32Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK32

# TODO @ngxson : remove this when we support MTP for deepseek models
skip_mtp = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
assert tokenizer.add_bos_token, "Change value of add_bos_token to true in tokenizer_config.json file."
self._set_vocab_gpt2()

def set_gguf_parameters(self):

# note: deepseek32 using MLA converts into MQA (ie: GQA with 1 group)
self.hparams["num_key_value_heads"] = 1

super().set_gguf_parameters()
hparams = self.hparams

# first_k_dense_replace: number of leading layers using dense FFN instead of MoE
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])

# note: deepseek32 using MLA converts into MQA with larger heads, then decompresses to MHA
self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])

# MoE parameters (required by C++ code for DEEPSEEK32 arch)
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])

self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
# ref https://github.com/ggml-org/llama.cpp/pull/17945
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_mscale_all)

# NextN/MTP prediction layers
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)

# DSA indexer parameters
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("language_model."):
name = name.replace("language_model.", "")

# rename e_score_correction_bias tensors
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")

# skip Multi-Token Prediction (MTP) layers
if self.skip_mtp:
block_count = self.hparams["num_hidden_layers"]
match = re.match(r"model.layers.(\d+)", name)
if match and int(match.group(1)) >= block_count:
return

# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

yield from super().modify_tensors(data_torch, merged_name, bid)
return
else:
return

# note: MLA with the absorption optimization, needs these two split and k_b_proj transposed
if name.endswith("kv_b_proj.weight"):
name_kb = name.replace("kv_b_proj", "k_b_proj")
name_vb = name.replace("kv_b_proj", "v_b_proj")

n_head_kv = self.hparams["num_key_value_heads"]
v_head_dim = self.hparams["v_head_dim"]
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]

assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)

yield from super().modify_tensors(k_b, name_kb, bid)
yield from super().modify_tensors(v_b, name_vb, bid)
return

yield from super().modify_tensors(data_torch, name, bid)

def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register(
"Mistral3ForConditionalGeneration",
"Ministral3ForCausalLM",
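A minimal torch sketch of the kv_b_proj handling in modify_tensors() above: the combined low-rank KV projection is reshaped per KV head, split into the k_b_proj and v_b_proj parts, and k_b_proj is transposed for the MLA absorption optimization. The dimensions below are illustrative DeepSeek-V3-style values chosen for this sketch, not read from any config:

```python
import torch

# illustrative dimensions only (roughly DeepSeek-V3-style), assumed for this sketch
n_head_kv        = 4      # toy head count
qk_nope_head_dim = 128
v_head_dim       = 128
kv_lora_rank     = 512    # with MLA stored as MQA, cached key length = kv_lora_rank + qk_rope_head_dim

# kv_b_proj.weight has shape [n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
kv_b = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# same steps as modify_tensors() above
kv_b = kv_b.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_b.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)

print(k_b.shape)  # torch.Size([4, 512, 128]) -> k_b_proj, transposed
print(v_b.shape)  # torch.Size([4, 128, 512]) -> v_b_proj
```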
13 changes: 13 additions & 0 deletions ggml/include/ggml.h
@@ -559,6 +559,7 @@ extern "C" {
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_GATED_DELTA_NET,
GGML_OP_SCATTER,

GGML_OP_UNARY,

@@ -2481,6 +2482,18 @@ extern "C" {
struct ggml_tensor * beta,
struct ggml_tensor * state);

GGML_API struct ggml_tensor * ggml_scatter(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * ids,
float c);

GGML_API struct ggml_tensor * ggml_scatter_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * ids,
float c);

// custom operators

typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
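Judging from the CPU kernels further down, ggml_scatter(ctx, a, ids, c) copies a to the result (the _inplace variant skips the copy) and then overwrites the elements selected by ids along the innermost dimension of each row with the constant c; per the commit history above, it replaces the earlier GGML_OP_WHERE_ID used in the lightning indexer. A rough torch equivalent, purely as an illustrative sketch (the helper name is ours, not part of ggml):

```python
import torch

def scatter_constant(a: torch.Tensor, ids: torch.Tensor, c: float) -> torch.Tensor:
    # out-of-place sketch: copy a, then set the indexed positions of each row to c
    out = a.clone()
    out.scatter_(dim=-1, index=ids.long(), value=c)
    return out

a   = torch.zeros(2, 8)
ids = torch.tensor([[1, 3], [0, 7]])
print(scatter_constant(a, ids, -1.0))
# row 0: positions 1 and 3 become -1.0; row 1: positions 0 and 7 become -1.0
```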
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -2031,6 +2031,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_gated_delta_net(params, tensor);
} break;
case GGML_OP_SCATTER:
{
ggml_compute_forward_scatter(params, tensor);
} break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@@ -2350,6 +2354,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_SCATTER:
{
n_tasks = n_threads;
} break;
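GGML_OP_SCATTER is registered here with n_tasks = n_threads; the kernels in ops.cpp (next file) then split rows across threads by ceil-division. A small sketch of that partitioning, for illustration only:

```python
def thread_row_range(nr: int, nth: int, ith: int) -> tuple[int, int]:
    # rows per thread, rounded up; the last thread may get a shorter range
    dr = (nr + nth - 1) // nth
    ir0 = dr * ith
    ir1 = min(ir0 + dr, nr)
    return ir0, ir1

# e.g. 10 rows over 4 threads -> [(0, 3), (3, 6), (6, 9), (9, 10)]
print([thread_row_range(10, 4, ith) for ith in range(4)])
```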
181 changes: 180 additions & 1 deletion ggml/src/ggml-cpu/ops.cpp
@@ -2232,8 +2232,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg
}
}

static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));

GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);

const auto [ir0, ir1] = get_thread_range(params, dst);

for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne2*ne1);
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);

ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);

ggml_vec_set_f16(ne0, dst_ptr, c);
}
}

void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
ggml_compute_forward_fill_f32(params, dst);
const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_fill_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_fill_f16(params, dst);
} break;
default:
{
GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type));
}
}
}

// ggml_compute_tri
@@ -11205,3 +11239,148 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_
}
}
}

// ggml_compute_forward_scatter

static void ggml_compute_forward_scatter_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

const float c = ggml_get_op_params_f32(dst, 0);
const bool inplace = ggml_get_op_params_i32(dst, 1);

GGML_ASSERT(ggml_are_same_shape(src0, dst));

GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_I32);

GGML_ASSERT(src0->nb[0] == sizeof(float));

const int ith = params->ith;
const int nth = params->nth;

const int nr = ggml_nrows(src0);

GGML_TENSOR_BINARY_OP_LOCALS

GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

const float * src0_ptr = (float *) ((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 );
const int32_t * ids_ptr = (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );

// copy whole row from src0
if (!inplace) {
ggml_vec_cpy_f32(ne00, dst_ptr, src0_ptr);
}

// set dst elements indicated by indices in src1 to c
for (int j = 0; j < ne10; ++j) {
int id = ids_ptr[j];
GGML_ASSERT(id >= 0 && id < ne00);
dst_ptr[id] = c;
}
}
}

static void ggml_compute_forward_scatter_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));
const bool inplace = ggml_get_op_params_i32(dst, 1);

GGML_ASSERT(ggml_are_same_shape(src0, dst));

GGML_ASSERT(dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_I32);

GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));

const int ith = params->ith;
const int nth = params->nth;

const int nr = ggml_nrows(src0);

GGML_TENSOR_BINARY_OP_LOCALS

GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 );
const int32_t * ids_ptr = (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );

// copy whole row from src0
if (!inplace) {
// ggml_vec_cpy_f16(ne00, dst_ptr, src0_ptr)
for (int i = 0; i < ne00; ++i) {
dst_ptr[i] = src0_ptr[i];
}
}

// set dst elements indicated by indices in src1 to c
for (int j = 0; j < ne10; ++j) {
int id = ids_ptr[j];
GGML_ASSERT(id >= 0 && id < ne00);
dst_ptr[id] = c;
}
}
}

void ggml_compute_forward_scatter(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_scatter_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_scatter_f16(params, dst);
} break;
default:
{
GGML_ABORT("unsupported type for ggml_compute_forward_scatter: %s", ggml_type_name(src0->type));
}
}
}
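For completeness, a NumPy sketch that mirrors the row-wise loop of the two kernels above, including the inplace flag that skips the initial row copy; the names and shapes here are ours, intended only as a rough reference for testing, not as part of the PR:

```python
import numpy as np

def scatter_forward_ref(src0: np.ndarray, ids: np.ndarray, c: float, inplace: bool = False) -> np.ndarray:
    # src0: [ne3, ne2, ne1, ne0]; ids: [ne3, ne2, ne1, n_idx] with 0 <= id < ne0
    dst = src0 if inplace else src0.copy()
    ne3, ne2, ne1, _ = src0.shape
    for i3 in range(ne3):
        for i2 in range(ne2):
            for i1 in range(ne1):
                row = dst[i3, i2, i1]
                for idx in ids[i3, i2, i1]:
                    row[idx] = c  # overwrite the selected elements with the constant
    return dst

x   = np.zeros((1, 1, 2, 8), dtype=np.float32)
ids = np.array([[[[1, 3], [0, 7]]]], dtype=np.int32)
print(scatter_forward_ref(x, ids, -1.0))
```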