Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ __global__ void per_token_quant_fp8_kernel(

float warp_max = warpReduceMax(max_value);

- __shared__ float scale;
- scale = warp_max / FP8_E4M3_MAX;
+ // NOTE: one CTA has multiple warps (each warp handles one token), so `scale`
+ // must be per-warp/per-thread (register) instead of a single shared variable.
+ const float scale = warp_max / FP8_E4M3_MAX;
  // Broadcast scale
  if (lane_id == 0) {
  token_scale[0] = scale;
  }
- float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
+ const float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;

//
// Pass-2: quantize and write back
Expand Down
Loading