122 changes: 106 additions & 16 deletions sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
@@ -1,18 +1,95 @@
#include <ATen/cuda/CUDAContext.h>

#include <cmath>
#include <cub/block/block_reduce.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

static constexpr int kWarpSize = 32;

// ---------------------------------------------------------------------------
// 1. Warp‑local, no shared memory
// • One warp handles one token.
// • Eight tokens per 256‑thread CTA.
// ---------------------------------------------------------------------------
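// Example mapping with kTokensPerCTA = 8 (blockDim.x = 256):
//   threadIdx.x = 137, blockIdx.x = 3  ->  warp_id = 4, lane_id = 9,
//   token_id = 3 * 8 + 4 = 28.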
template <typename T, int kTokensPerCTA = 8, int kVecSize = 16>
__global__ void per_token_quant_fp8_kernel(
const T* __restrict__ input,
FP8_TYPE* __restrict__ output_q,
float* __restrict__ output_s,
const int64_t hidden_dim,
const int64_t num_tokens) {
const int warp_id = threadIdx.x / kWarpSize; // 0‑7 (8 warps)
const int lane_id = threadIdx.x & (kWarpSize - 1); // 0‑31
const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
if (token_id >= num_tokens) return;

// Global tensors for this token
const T* token_input = input + token_id * hidden_dim;
FP8_TYPE* token_output = output_q + token_id * hidden_dim;
float* token_scale = output_s + token_id;

//
// Pass 1: warp-level reduction to find the max absolute value across the
// token's hidden_dim elements.
//
float max_value = 0.f;
using vec_t = flashinfer::vec_t<T, kVecSize>;
const int32_t num_vec_elems = hidden_dim / kVecSize;
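// Each lane loads kVecSize (16) elements per iteration and the warp strides
// 32 vectors at a time; e.g. hidden_dim = 4096 gives num_vec_elems = 256,
// i.e. 8 iterations per lane.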

for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
vec_t input_vec;
input_vec.cast_load(token_input + i * kVecSize);

#pragma unroll
for (uint32_t j = 0; j < kVecSize; ++j) {
max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
}
}

float warp_max = warpReduceMax(max_value);

// warpReduceMax leaves the same maximum in every lane, so the scale can
// stay in a register; a single __shared__ variable would race between the
// eight warps (one token each) in the CTA.
const float scale = warp_max / FP8_E4M3_MAX;
if (lane_id == 0) {
token_scale[0] = scale;
}
// Guard against all-zero tokens to avoid a division by zero.
const float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;

//
// Pass 2: quantize and write back.
//
for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
vec_t input_vec;
input_vec.cast_load(token_input + i * kVecSize);
FP8_TYPE output_arr[kVecSize];
#pragma unroll
for (uint32_t j = 0; j < kVecSize; ++j) {
float val = static_cast<float>(input_vec[j]) * scale_inv;
val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);

#ifndef USE_ROCM
output_arr[j] = static_cast<FP8_TYPE>(val);
#else
output_arr[j] = c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
c10::Float8_e4m3fnuz::from_bits());
#endif
}
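// 16 one-byte FP8 values fit exactly in one uint4, so the whole vector is
// written with a single 16-byte store; the address is 16-byte aligned
// because hidden_dim is a multiple of 16 (checked on the host side).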
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
}
}

// ---------------------------------------------------------------------------
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
// ---------------------------------------------------------------------------
template <typename T>
__global__ void per_token_quant_fp8_small_batch_kernel(
const T* __restrict__ input,
FP8_TYPE* __restrict__ output_q,
float* __restrict__ output_s,
const int64_t hidden_dim,
const int64_t num_tokens) {
const int token_idx = blockIdx.x;
if (token_idx >= num_tokens) return;

@@ -79,28 +156,41 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
CHECK_INPUT(input);
CHECK_INPUT(output_q);
CHECK_INPUT(output_s);

const auto input_sizes = input.sizes();
const int64_t num_tokens = input_sizes[0];
const int64_t hidden_dim = input_sizes[1];

TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim);
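// Both kernels move data in 16-element vectors (kVecSize = 16), so a
// hidden_dim that is not a multiple of 16 cannot be tiled cleanly.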

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Hard-coded SM count (132 SMs, matching H100), used only by the
// kernel-selection heuristic below.
int sm_count = 132;
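// A device-portable alternative (an assumption, not part of this change):
//   int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;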
constexpr int TOKENS_PER_CTA = 8;
const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
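// With sm_count = 132 this selects the warp kernel once
// num_tokens >= 132 * 2 * 8 = 2112, i.e. at least two CTAs' worth of
// tokens per SM; smaller batches fall back to the one-token-per-CTA kernel.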

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
if (use_warp_kernel) {
// -------- warp‑local ---------------------------------------------------
constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256
dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
dim3 block(THREADS);
per_token_quant_fp8_kernel<scalar_t, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
static_cast<const scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
hidden_dim,
num_tokens);
} else {
// -------- baseline -----------------------------------------------------
constexpr int THREADS = 256;
dim3 grid(num_tokens);
dim3 block(THREADS);
per_token_quant_fp8_small_batch_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<const scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
hidden_dim,
num_tokens);
}
return true;
});
}