Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 67 additions & 2 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3366,6 +3366,69 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
return false;
}

// Returns true when the fusion is memory-safe: none of the fused write (out)
// nodes overwrite a memory range that is read as a src by any node in the
// fusion, unless that src is an intermediate tensor produced earlier inside
// the fusion (and therefore elided by the fused kernel).
//
// cgraph     - graph containing the candidate nodes
// node_idx   - index of the first node of the fusion
// node_count - number of consecutive nodes being fused
// out_nodes  - indices of the nodes whose outputs are actually written
// out_count  - number of entries in out_nodes
static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
                                                 int node_idx,
                                                 int node_count,
                                                 int * out_nodes,
                                                 int out_count) {
    // Half-open [start, end) interval overlap test on device addresses.
    // uintptr_t avoids the implementation-defined pointer -> signed integer cast.
    auto nodes_overlap = [](const ggml_tensor * a, const ggml_tensor * b) {
        const uintptr_t a_start = (uintptr_t) a->data;
        const uintptr_t a_end   = a_start + ggml_nbytes(a);

        const uintptr_t b_start = (uintptr_t) b->data;
        const uintptr_t b_end   = b_start + ggml_nbytes(b);

        return (b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end);
    };

    // for nrows=1, all fusion operations correctly read the src before writing
    // dst or do it elementwise, so aliasing is harmless
    if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
        return true;
    }

    for (int i = 0; i < out_count; ++i) {
        const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];

        // Loop over all srcs of all nodes in the fusion. If a src overlaps the
        // destination and the src is not an intermediate node that's being
        // elided, then fusion is unsafe.
        for (int j = node_idx; j < node_idx + node_count; ++j) {
            for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
                const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];

                if (!src || src->op == GGML_OP_NONE) {
                    continue;
                }

                if (!nodes_overlap(dst, src)) {
                    continue;
                }

                // src aliases dst: this is only acceptable when src is
                // produced by an earlier node inside this same fusion.
                bool is_intermediate = false;
                for (int k = node_idx; k < j; ++k) {
                    if (cgraph->nodes[k] == src) {
                        is_intermediate = true;
                        break;
                    }
                }

                if (!is_intermediate) {
                    // hazardous read/write overlap -> disable fusion
                    return false;
                }
            }
        }
    }

    return true;
}

static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
bool graph_evaluated_or_captured = false;

Expand Down Expand Up @@ -3562,7 +3625,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
out_nodes[1] = i + ops.size() - 1;

if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
Expand All @@ -3577,7 +3641,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud

int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
Expand Down
12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/topk-moe.cu
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}

// Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs.
// NaN comparisons always return false, which would cause the same expert to be
// selected repeatedly. -FLT_MAX compares normally and is still excluded by the
// -INFINITY sentinel used after each selection round.
// More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
if (__isnanf(wt[i])) {
wt[i] = -FLT_MAX;
}
}
Comment on lines +122 to +132
Copy link
Collaborator

@ORippler ORippler Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. If the issue is in llama.cpp and not cuBLAS, I feel we should use fmaxf as a NaN-safe comparator: https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__SINGLE.html#_CPPv45fmaxfff (I presume we are talking about val_s > max_val_s later on in this kernel?)
  2. If the issue is in cuBLAS, I'd love more details so I can ask the cuBLAS team/take a look myself

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Yes but it's not just val_s > max_val_s it's val_s > max_val_s || (val_s == max_val_s && expert < max_expert)
  2. The linked issue has a repro. It's cuBLAS + Nemotron, so think it would be fun for you guys to look at :)

Copy link
Collaborator

@ORippler ORippler Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but it's not just val_s > max_val_s it's val_s > max_val_s || (val_s == max_val_s && expert < max_expert)

Shouldn't we be fine with fmaxf, so long as max_val & max_val_s are initialized to -FLT_MAX instead of -INFINITY at the beginning of the selection-loop over n_expert_used? At least for the case where k non-NAN values exist inside the logits for a given row. But at this point we are just pulling your proposal into the loop itself 😄


// selection_wt is only needed when bias is present (selection uses wt + bias)
// when no bias, we use wt directly for both selection and weight values
float selection_wt[has_bias ? experts_per_thread : 1];
Expand Down
Loading