ggml-cuda: add mem check for fusion #19916
```diff
@@ -119,6 +119,18 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }

+    // Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs.
+    // NaN comparisons always return false, which would cause the same expert to be
+    // selected repeatedly. -FLT_MAX compares normally and is still excluded by the
+    // -INFINITY sentinel used after each selection round.
+    // More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        if (__isnanf(wt[i])) {
+            wt[i] = -FLT_MAX;
+        }
+    }
```
|
Comment on lines +122 to +132

Collaborator: Shouldn't we be fine with `fmaxf`, so long as
```cpp
// selection_wt is only needed when bias is present (selection uses wt + bias)
// when no bias, we use wt directly for both selection and weight values
float selection_wt[has_bias ? experts_per_thread : 1];
```
This would maybe be slightly simpler but either way is fine I think.