27 changes: 19 additions & 8 deletions csrc/nv_internal/cpp/kernels/quantization.cu
@@ -240,13 +240,14 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
}
}

template <typename T>
__global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded,
int numCols, int numColsPadded, uint8_t const* SFIn,
uint8_t* SFOutput) {
int numCols, int numColsPadded, T const* SFIn,
T* SFOutput) {
for (int rowIdx = blockIdx.x; rowIdx < numRowsPadded; rowIdx += gridDim.x) {
for (int batchIdx = 0; batchIdx < numBatches; batchIdx++) {
for (int colIdx = threadIdx.x; colIdx < numColsPadded; colIdx += blockDim.x) {
uint8_t sf = 0;
T sf = 0;
if (rowIdx < numRows && colIdx < numCols) {
int64_t inOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
sf = SFIn[inOffset];
@@ -287,19 +288,29 @@ __global__ void block_scale_interleave_reverse_kernel(int numBatches, int numRow
}

// This is intended for weight loading, so m and n are large, b <= 256
void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount,
cudaStream_t stream) {
template <typename T>
void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, T const* SFIn,
T* SFOutput, int multiProcessorCount, cudaStream_t stream) {
// Each thread reads 1 int8 value
dim3 block(std::min(n_padded, 1024));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = std::max(1u, 4096u / block.x);
dim3 grid(std::min(m_padded, multiProcessorCount * numBlocksPerSM));

block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn,
SFOutput);
block_scale_interleave_kernel<T>
<<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn, SFOutput);
}
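
For concreteness, the launch configuration above works out as follows for one plausible shape; the SM count and the sizes are illustrative values, not taken from this patch:

// Example: n_padded = 512, m_padded = 4096, multiProcessorCount = 132 (all illustrative).
// block.x        = min(n_padded, 1024)    = 512 threads
// numBlocksPerSM = max(1, 4096 / block.x) = 8
// grid.x         = min(m_padded, 132 * 8) = 1056 blocks
// Each block then grid-strides over rows (rowIdx += gridDim.x), so every block
// handles roughly m_padded / grid.x, i.e. about 4 padded rows in this example.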

// Explicit template instantiations for the types used by other compilation units
template void invokeBlockScaleInterleave<uint8_t>(int b, int m, int m_padded, int n, int n_padded,
uint8_t const* SFIn, uint8_t* SFOutput,
int multiProcessorCount, cudaStream_t stream);
template void invokeBlockScaleInterleave<__nv_bfloat16>(int b, int m, int m_padded, int n,
int n_padded, __nv_bfloat16 const* SFIn,
__nv_bfloat16* SFOutput,
int multiProcessorCount,
cudaStream_t stream);
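
A minimal host-side sketch of calling the newly templated launcher for bfloat16 scale factors; the helper padUp, the include path, and the buffer names are assumptions for illustration and are not part of this diff:

#include <cuda_bf16.h>
#include "tensorrt_llm/kernels/quantization.h"  // assumed include path for the declaration

// d_sf_in / d_sf_out are device buffers sized for the padded, swizzled layout.
void interleaveBf16BlockScales(__nv_bfloat16 const* d_sf_in, __nv_bfloat16* d_sf_out, int b,
                               int m, int n, int smCount, cudaStream_t stream) {
  auto padUp = [](int x, int y) { return ((x + y - 1) / y) * y; };  // illustrative helper
  int m_padded = padUp(m, 128);  // rows padded to the 128-row swizzle tile
  int n_padded = padUp(n, 4);    // columns padded to groups of 4 scale factors
  tensorrt_llm::kernels::invokeBlockScaleInterleave<__nv_bfloat16>(
      b, m, m_padded, n, n_padded, d_sf_in, d_sf_out, smCount, stream);
}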

// This is intended for weight loading, so m and n are large, b <= 256
void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput,
int multiProcessorCount, cudaStream_t stream) {
6 changes: 3 additions & 3 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.h
@@ -67,9 +67,9 @@ void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* i
void* input_global_scale, void* mask, bool use_silu_and_mul,
int m_topk, int k, int n_experts, cudaStream_t stream);

void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount,
cudaStream_t stream = 0);
template <typename T>
void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, T const* SFIn,
T* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);

void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput,
int multiProcessorCount, cudaStream_t stream = 0);
79 changes: 58 additions & 21 deletions csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp
@@ -137,6 +137,41 @@ int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn,
}
}

template <typename T>
void blockScaleInterleaveHost(TensorView blockScale, TensorView interleavedBlockScale) {
auto blockScaleShape = blockScale.sizes();
auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1;
auto rows = blockScaleShape.size() == 3 ? blockScaleShape[1] : blockScaleShape[0];
auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1];

auto expert_out_size = tensorrt_llm::computeSwizzledLayoutSFSize(rows, cols);
auto rows_padded = PadUpFn(rows, 128);
auto cols_padded = PadUpFn(cols, 4);

for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
T* interleavedBlockScalePtr =
static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
auto globalRowIdx = eIdx * rows + rIdx;
T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols;
for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
T sf_ori = 0;
if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
sf_ori = blockScalePtr[cIdx];
}
int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
interleavedBlockScalePtr[sf_index] = sf_ori;
}
}
}
}

template void blockScaleInterleaveHost<uint8_t>(TensorView blockScale,
TensorView interleavedBlockScale);
template void blockScaleInterleaveHost<__nv_bfloat16>(TensorView blockScale,
TensorView interleavedBlockScale);
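
A quick worked example of the padding arithmetic used here; it assumes computeSwizzledLayoutSFSize(rows, cols) equals pad_up(rows, 128) * pad_up(cols, 4), which matches the return-size comment below but is not shown in this diff:

// Illustrative shapes only.
// blockScale: [num_experts = 8, rows = 300, cols = 6]
// rows_padded     = PadUpFn(300, 128) = 384
// cols_padded     = PadUpFn(6, 4)     = 8
// expert_out_size = 384 * 8           = 3072 scale factors per expert (assumed)
// interleavedBlockScale then holds 8 * 3072 = 24576 entries, swizzled in 128x4 tiles.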

// Interleave (and possibly pad) the weights block scaling factor.
// blockScale: [num_experts, rows, cols] or [rows, cols]
// Return: num_experts * pad_up(rows, 128) * pad_up(cols, 4)
@@ -148,7 +183,8 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal
CHECK_CPU(blockScale);
}
CHECK_CONTIGUOUS(blockScale);
CHECK_INPUT_TYPE(blockScale, dl_uint8);
TVM_FFI_ICHECK(blockScale.dtype() == dl_uint8 || blockScale.dtype() == dl_bfloat16)
<< "Block Scale must be uint8 or bfloat16.";
auto blockScaleShape = blockScale.sizes();
TVM_FFI_ICHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3)
<< "Block Scale should be 2D or 3D tensor.";
@@ -166,27 +202,28 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal
const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount();
const cudaStream_t stream = get_stream(blockScale.device());

tensorrt_llm::kernels::invokeBlockScaleInterleave(
num_experts, rows, rows_padded, cols, cols_padded,
static_cast<uint8_t*>(blockScale.data_ptr()),
static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream);
if (blockScale.dtype() == dl_uint8) {
tensorrt_llm::kernels::invokeBlockScaleInterleave(
num_experts, rows, rows_padded, cols, cols_padded,
static_cast<uint8_t*>(blockScale.data_ptr()),
static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream);
} else if (blockScale.dtype() == dl_bfloat16) {
tensorrt_llm::kernels::invokeBlockScaleInterleave(
num_experts, rows, rows_padded, cols, cols_padded,
static_cast<__nv_bfloat16*>(blockScale.data_ptr()),
static_cast<__nv_bfloat16*>(interleavedBlockScale.data_ptr()), smCount, stream);
} else {
TVM_FFI_LOG_AND_THROW(NotImplementedError)
<< "block_scale_interleave only supports uint8 and bfloat16.";
}
} else {
for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
uint8_t* interleavedBlockScalePtr =
static_cast<uint8_t*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
auto globalRowIdx = eIdx * rows + rIdx;
uint8_t* blockScalePtr = static_cast<uint8_t*>(blockScale.data_ptr()) + globalRowIdx * cols;
for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
uint8_t sf_ori = 0;
if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
sf_ori = blockScalePtr[cIdx];
}
int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
interleavedBlockScalePtr[sf_index] = sf_ori;
}
}
if (blockScale.dtype() == dl_uint8) {
blockScaleInterleaveHost<uint8_t>(blockScale, interleavedBlockScale);
} else if (blockScale.dtype() == dl_bfloat16) {
blockScaleInterleaveHost<__nv_bfloat16>(blockScale, interleavedBlockScale);
} else {
TVM_FFI_LOG_AND_THROW(NotImplementedError)
<< "blockScaleInterleaveHost only supports uint8 and bfloat16.";
}
}
}
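
The uint8/bfloat16 branching now appears twice (once for the CUDA path, once for the CPU path). A possible follow-up, sketched below under the assumption that DLDataType values compare with == as in the code above, would centralize it in one dispatcher; this is only an illustration, not part of the patch:

// Sketch of a shared dtype dispatcher; is_cuda, smCount, stream, and the shape variables
// stand in for the locals of BlockScaleInterleave and are illustrative here.
template <typename Fn>
void dispatchBlockScaleDtype(DLDataType dtype, Fn&& fn) {
  if (dtype == dl_uint8) {
    fn(uint8_t{});
  } else if (dtype == dl_bfloat16) {
    fn(__nv_bfloat16{});
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError)
        << "block scale interleave only supports uint8 and bfloat16.";
  }
}

// Usage (sketch):
// dispatchBlockScaleDtype(blockScale.dtype(), [&](auto tag) {
//   using T = decltype(tag);
//   if (is_cuda) {
//     tensorrt_llm::kernels::invokeBlockScaleInterleave<T>(
//         num_experts, rows, rows_padded, cols, cols_padded,
//         static_cast<T const*>(blockScale.data_ptr()),
//         static_cast<T*>(interleavedBlockScale.data_ptr()), smCount, stream);
//   } else {
//     blockScaleInterleaveHost<T>(blockScale, interleavedBlockScale);
//   }
// });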