Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2464,6 +2464,8 @@ extern "C" {
bool lower,
bool uni);

// TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
// ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10436,8 +10436,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(

const float * state_in_base = (const float *)src_state->data;

const int64_t rq1 = nev1 / neq1;
const int64_t rk1 = nev1 / nek1;
//const int64_t rq1 = nev1 / neq1;
//const int64_t rk1 = nev1 / nek1;
const int64_t rq3 = nev3 / neq3;
const int64_t rk3 = nev3 / nek3;

Expand All @@ -10447,8 +10447,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const int64_t iv1 = ir % H; // head_index
const int64_t iv3 = ir / H; // sequence

const int64_t iq1 = iv1 / rq1;
const int64_t ik1 = iv1 / rk1;
const int64_t iq1 = iv1 % neq1;
const int64_t ik1 = iv1 % nek1;

const int64_t iq3 = iv3 / rq3;
const int64_t ik3 = iv3 / rk3;
Expand Down
36 changes: 20 additions & 16 deletions ggml/src/ggml-cuda/gated_delta_net.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ __global__ void gated_delta_net_cuda(const float * q,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t rq1,
int64_t neqk1,
int64_t rq3,
float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column

const int64_t iq1 = h_idx / rq1;
const int64_t iq1 = h_idx % neqk1;
const int64_t iq3 = sequence / rq3;

const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
Expand Down Expand Up @@ -119,11 +119,11 @@ static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
float * dst_d,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t rq1, int64_t rq3,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t neqk1, int64_t rq3,
float scale, cudaStream_t stream) {

dim3 grid_dims(H, n_seqs, 1);
Expand All @@ -134,19 +134,19 @@ static void launch_gated_delta_net(
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1, rq3, scale);
break;
case 64:
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1, rq3, scale);
break;
case 128:
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1, rq3, scale);
break;
default:
GGML_ABORT("fatal error");
Expand All @@ -163,10 +163,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_tensor * src_state = dst->src[5];

GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);

const int64_t S_v = nev0;
const int64_t H = nev1;
Expand All @@ -175,7 +177,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *

const bool kda = (src_g->ne[0] == S_v);

const int64_t rq1 = nev1 / neq1;
GGML_ASSERT(neq1 == nek1);
const int64_t neqk1 = neq1;

const int64_t rq3 = nev3 / neq3;

const float * q_d = (const float *) src_q->data;
Expand Down Expand Up @@ -214,10 +218,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
if (kda) {
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, stream);
} else {
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, stream);
}
}
4 changes: 3 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4999,7 +4999,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifdef GGML_USE_MUSA
return false;
#else
return true;
KDA is faster using the AR kernel even when n_tokens >= 512.
//TODO: Add chunked kernel
return op->src[0]->ne[2] == 1 || op->src[3]->ne[0] == op->src[2]->ne[0];
#endif // GGML_USE_MUSA
case GGML_OP_FLASH_ATTN_EXT:
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
Expand Down
95 changes: 70 additions & 25 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ llama_context::llama_context(
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;

cparams.fused_gdn_ar = true;
cparams.fused_gdn_ch = false; // TODO: implement
cparams.fused_gdn_ch = true;
cparams.auto_fgdn = true;

// with causal attention, the batch size is limited by the context size
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
Expand Down Expand Up @@ -462,37 +463,81 @@ void llama_context::sched_reserve() {
cparams.auto_fa = false;
}

if (cparams.fused_gdn_ar) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check");
}
if (cparams.auto_fgdn) {
LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net\n", __func__);

const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
if (cparams.fused_gdn_ar) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
}
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));

GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
gdn_device_mismatch = true;
break;
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
}
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));

GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
gdn_device_mismatch = true;
break;
}
}

if (gdn_device_mismatch) {
cparams.fused_gdn_ar = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
} else {
LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
}
}

if (gdn_device_mismatch) {
cparams.fused_gdn_ar = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__);
if (cparams.fused_gdn_ch) {
// more than one token in the batch per sequence in order to take the chunked path
auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
}

const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
}
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));

GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
gdn_device_mismatch = true;
break;
}
}

if (gdn_device_mismatch) {
cparams.fused_gdn_ch = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
} else {
LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
}
}

cparams.auto_fgdn = false;
}

// reserve worst-case graph
Expand Down
1 change: 1 addition & 0 deletions src/llama-cparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct llama_cparams {
bool auto_fa;
bool fused_gdn_ar; // use fused gated delta net (autoregressive)
bool fused_gdn_ch; // use fused gated delta net (chunked)
bool auto_fgdn;
bool no_perf;
bool warmup;
bool op_offload;
Expand Down
6 changes: 3 additions & 3 deletions src/llama-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t);

std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);

#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
#define LLAMA_TENSOR_NAME_FGDNAR "__fgdnar__"
#define LLAMA_TENSOR_NAME_FGDNCH "__fgdnch__"
#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
#define LLAMA_TENSOR_NAME_FGDN_AR "__fgdn_ar__"
#define LLAMA_TENSOR_NAME_FGDN_CH "__fgdn_ch__"
21 changes: 17 additions & 4 deletions src/models/delta-net-base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,23 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);

if (cparams.fused_gdn_ch) {
//ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
//cb(result, LLAMA_TENSOR_NAME_FGDNCH, il);
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il);

ggml_tensor * output = ggml_view_4d(ctx0, result,
S_v, H_v, n_tokens, n_seqs,
ggml_row_size(result->type, S_v),
ggml_row_size(result->type, S_v * H_v),
ggml_row_size(result->type, S_v * H_v * n_tokens), 0);

GGML_ABORT("not implemented yet");
ggml_tensor * new_state = ggml_view_4d(ctx0, result,
S_v, S_v, H_v, n_seqs,
ggml_row_size(result->type, S_v),
ggml_row_size(result->type, S_v * S_v),
ggml_row_size(result->type, S_v * S_v * H_v),
ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));

return {output, new_state};
}

const float scale = 1.0f / sqrtf(S_k);
Expand Down Expand Up @@ -327,7 +340,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne

if (cparams.fused_gdn_ar) {
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
cb(result, LLAMA_TENSOR_NAME_FGDNAR, il);
cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);

ggml_tensor * output = ggml_view_4d(ctx0, result,
S_v, H_v, n_tokens, n_seqs,
Expand Down
4 changes: 2 additions & 2 deletions src/models/qwen35.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

// if head keys and value keys are different, repeat to force tensors into matching shapes
if (num_k_heads != num_v_heads) {
// note: need explicit repeat only if we are not using the fused GDN
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
// TODO: try to avoid these explicit repeats by utilizing op broadcast
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
}
Expand Down
4 changes: 2 additions & 2 deletions src/models/qwen35moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

// if head keys and value keys are different, repeat to force tensors into matching shapes
if (num_k_heads != num_v_heads) {
// note: need explicit repeat only if we are not using the fused GDN
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
// TODO: try to avoid these explicit repeats by utilizing op broadcast
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
}
Expand Down
1 change: 1 addition & 0 deletions src/models/qwen3next.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

// if head keys and value keys are different, repeat to force tensors into matching shapes
// TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST]
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
int64_t repeat_factor = num_v_heads / num_k_heads;
Expand Down
Loading