@@ -29,43 +29,62 @@ struct __align__(sizeof(T) * VecSize) VecType {
 }
 };
 
-template <int VecSize>
-__device__ void BlockLoad(const phi::bfloat16* input,
+template <typename InT, int VecSize>
+__device__ void BlockLoad(const InT* input,
+                          const float* input_scales,
                           __nv_bfloat16 x[8][4],
-                          size_t K) {
+                          size_t K,
+                          size_t k_scaled) {
+  constexpr bool need_dequant = std::is_same_v<InT, phi::dtype::float8_e4m3fn>;
+
+#pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
-    size_t off_m = blockIdx.x * size_t(128) + threadIdx.y + i * 16;
-    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
-    size_t offset = off_m * K + off_k;
+    const uint32_t local_off_M = threadIdx.y + i * 16;
+    const uint32_t off_m = blockIdx.x * 128 + local_off_M;
+    const uint32_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
+    const size_t offset = off_m * K + off_k;
+
+    float scale;
+    if constexpr (need_dequant) {
+      const uint32_t m_base = blockIdx.x * 128;
+      const uint32_t m_stride = k_scaled;
+      scale = input_scales[off_m * m_stride + blockIdx.y];
+    }
 
+#pragma unroll
     for (uint32_t j = 0; j < 4; j += VecSize) {
-      if (off_k + j * 32 < K) {
-        size_t idx = offset + j * 32;
-        using LoadT = VecType<__nv_bfloat16, VecSize>;
-        LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
-        for (uint32_t k = 0; k < VecSize; k++) {
-          x[i][j + k] = data[k];
+      const size_t idx = offset + j * 32;
+      using LoadT = VecType<InT, VecSize>;
+      LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
+#pragma unroll
+      for (uint32_t k = 0; k < VecSize; k++) {
+        if constexpr (need_dequant) {
+          x[i][j + k] = __float2bfloat16(static_cast<float>(data[k]) * scale);
+        } else {
+          x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k]));
         }
       }
     }
   }
 }
-
 template <bool Pow2Scales>
 __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
-                                 float col_scale[128],
+                                 float scales[128],
                                  __nv_bfloat16* shm) {
   // reduce [(8), 16, 32, 4] => [16, 32, 4]
   __nv_bfloat16 warp_max[4];
+#pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
+#pragma unroll
     for (uint32_t j = 0; j < 4; j++) {
-      __nv_bfloat16 t = BF16_ABS(x[i][j]);
+      const __nv_bfloat16 t = BF16_ABS(x[i][j]);
       warp_max[j] = i == 0 ? t : BF16_MAX(warp_max[j], t);
     }
   }
 
   // reduce [(16), 32, 4] => [8, 32, 4]
   if (threadIdx.y >= 8) {
+#pragma unroll
     for (uint32_t j = 0; j < 4; j++) {
       shm[(threadIdx.y - 8) * 128 + threadIdx.x + j * 32] = warp_max[j];
     }
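
A note on the new dequant-on-load path above: when InT is FP8 (need_dequant), BlockLoad reads one float scale per (row, 128-wide K block), indexed as off_m * m_stride + blockIdx.y with m_stride = k_scaled, so input_scales is assumed to be a row-major [M, ceil(K / 128)] tensor. A small illustrative helper (the name and host-side framing are mine, not from the PR) that mirrors that indexing:

    // Illustrative only: the input_scales element that dequantizes element
    // (row, col) of an [M, K] FP8 input, assuming a row-major
    // [M, ceil(K / 128)] scale tensor as implied by BlockLoad above.
    #include <cstddef>

    inline size_t InputScaleIndex(size_t row, size_t col, size_t K) {
      const size_t k_scaled = (K + 127) / 128;  // 128-wide blocks along K
      const size_t block_k = col / 128;         // == blockIdx.y in the kernel
      return row * k_scaled + block_k;          // == off_m * m_stride + blockIdx.y
    }
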
@@ -75,8 +94,9 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
   // reduce [(8), 32, 4] => [32, 4]
   for (uint32_t offset = 8; offset > 0; offset /= 2) {
     if (threadIdx.y < offset) {
+#pragma unroll
       for (uint32_t j = 0; j < 4; j++) {
-        __nv_bfloat16 other =
+        const __nv_bfloat16 other =
             offset == 8
                 ? warp_max[j]
                 : shm[(threadIdx.y + offset) * 128 + threadIdx.x + j * 32];
@@ -85,7 +105,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
         if (offset > 1) {
           shm[threadIdx.y * 128 + threadIdx.x + j * 32] = next_val;
         } else {
-          col_scale[threadIdx.x + j * 32] =
+          scales[threadIdx.x + j * 32] =
               ComputeScale<__nv_bfloat16, __nv_fp8_e4m3, Pow2Scales>(
                   static_cast<float>(next_val), 0.0f);
         }
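
For context on the quantization step above: scales[...] holds the per-column multiplier that maps the bf16 column into FP8 E4M3 range, and BlockStoreScale later stores its reciprocal as the dequantization scale. The real ComputeScale helper is not part of this diff; the sketch below is only an assumed, illustrative behaviour (the name ComputeScaleSketch, the 448.0f E4M3 max, and the rounding direction for Pow2Scales are assumptions, not the actual implementation):

    // Illustrative sketch only -- NOT the real phi ComputeScale. Assumed
    // behaviour: turn a per-column absolute max into an FP8 E4M3 multiplier,
    // optionally snapped to a power of two when Pow2Scales is set.
    template <bool Pow2Scales>
    __device__ __forceinline__ float ComputeScaleSketch(float amax) {
      constexpr float kFp8E4m3Max = 448.0f;  // largest finite E4M3 value
      float scale = amax > 0.0f ? kFp8E4m3Max / amax : 1.0f;
      if constexpr (Pow2Scales) {
        // Assumption: round the multiplier down to a power of two so the
        // stored reciprocal (the dequant scale) becomes a power of two.
        scale = exp2f(floorf(log2f(scale)));
      }
      return scale;
    }
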
@@ -98,7 +118,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
 template <typename OutT, int VecSize>
 __device__ void BlockStoreScale(float* scale,
                                 size_t off_m,
-                                float col_scale[128],
+                                float scales[128],
                                 size_t K) {
   if (threadIdx.y < 4) {
     uint32_t off = threadIdx.y * 32 + threadIdx.x;
@@ -107,10 +127,10 @@ __device__ void BlockStoreScale(float* scale,
     } else if constexpr (VecSize == 2) {
       off = (off / 64) * 64 + (off % 2) * 32 + (off % 64) / 2;
     }
-    float scale_out = 1.0f / col_scale[off];
-    size_t idx_y = blockIdx.x - off_m / 128;
-    size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x;
-    size_t idx = idx_y * K + idx_x;
+    float scale_out = 1.0f / scales[off];
+    const size_t idx_y = blockIdx.x - off_m / 128;
+    const size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x;
+    const size_t idx = idx_y * K + idx_x;
     if (idx_x < K) {
       scale[idx] = scale_out;
     }
@@ -123,14 +143,16 @@ __device__ void BlockStoreOut(OutT* out,
                               size_t cur_tokens,
                               const OutT shm[128][129],
                               size_t K) {
+#pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
-    size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4;
-    size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16;
-    size_t idx = idx_k * cur_tokens + (idx_m - off_m);
+    const size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4;
+    const size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16;
+    const size_t idx = idx_k * cur_tokens + (idx_m - off_m);
 
     if (idx_k < K) {
       using StoreT = VecType<OutT, VecSize>;
       StoreT data;
+#pragma unroll
       for (uint32_t j = 0; j < VecSize; j++) {
         data[j] = shm[i * 16 + threadIdx.y][threadIdx.x * 4 + j];
       }
@@ -139,23 +161,27 @@ __device__ void BlockStoreOut(OutT* out,
   }
 }
 
-template <typename OutT, bool Pow2Scales, int VecSize>
+template <typename InT, typename OutT, bool Pow2Scales, int VecSize>
 __global__ void __launch_bounds__(512)
-    FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ input,
+    FusedTransposeSplitQuantKernel(const InT* __restrict__ input,
+                                   const float* __restrict__ input_scales,
                                    int64_t* __restrict__ meta,
                                    size_t num_experts,
-                                   size_t K) {
+                                   size_t K,
+                                   size_t k_scaled) {
   __shared__ OutT shm[128][129];
+  __shared__ size_t expert_info[2];
+  __shared__ float scales[128];  // Could this be reused? Is it worth it?
+
   int64_t* tokens_per_expert = meta;
   OutT** out_ptrs = reinterpret_cast<OutT**>(meta + num_experts);
   float** scale_ptrs = reinterpret_cast<float**>(meta + num_experts * 2);
 
   // 1. Load 128x128 elements from input
   __nv_bfloat16 x[8][4];
-  BlockLoad<VecSize>(input, x, K);
+  BlockLoad<InT, VecSize>(input, input_scales, x, K, k_scaled);
 
   // 2. Get expert index and offset of the current block
-  __shared__ size_t expert_info[2];
   if (threadIdx.x == 0 && threadIdx.y == 0) {
     size_t idx_m = blockIdx.x * size_t(128);
     size_t off_m = 0, next_off_m = 0;
@@ -172,21 +198,23 @@ __global__ void __launch_bounds__(512)
   }
 
   // 3. Calculate scale along the column
-  __shared__ float col_scale[128];
   BlockColumnScale<Pow2Scales>(
-      x, col_scale, reinterpret_cast<__nv_bfloat16*>(shm));
+      x, scales, reinterpret_cast<__nv_bfloat16*>(shm));
 
   // 4. Store scale
   const size_t expert_idx = expert_info[0];
   const size_t off_m = expert_info[1];
-  BlockStoreScale<OutT, VecSize>(scale_ptrs[expert_idx], off_m, col_scale, K);
+  BlockStoreScale<OutT, VecSize>(scale_ptrs[expert_idx], off_m, scales, K);
 
-  // 5. Scale x and save into shared memory with transposed layout
+  // 5. Scale x and save into shared memory with transposed layout
+#pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
+#pragma unroll
     for (uint32_t j = 0; j < 4; j += VecSize) {
+#pragma unroll
       for (uint32_t k = 0; k < VecSize; k++) {
         float x_fp32 = static_cast<float>(x[i][j + k]);
-        float x_scaled = x_fp32 * col_scale[threadIdx.x + (j + k) * 32];
+        float x_scaled = x_fp32 * scales[threadIdx.x + (j + k) * 32];
         shm[threadIdx.x * VecSize + j * 32 + k][i * 16 + threadIdx.y] =
             static_cast<OutT>(x_scaled);
       }
@@ -204,10 +232,11 @@ template <typename T, typename Context>
 void FusedTransposeSplitQuantKernel(
     const Context& dev_ctx,
     const DenseTensor& x,
+    const paddle::optional<DenseTensor>& input_scales,
     const std::vector<int64_t>& tokens_per_expert,
     bool pow_2_scales,
     std::vector<DenseTensor*> outs,
-    std::vector<DenseTensor*> scales) {
+    std::vector<DenseTensor*> output_scales) {
   auto x_dims = x.dims();
   const int64_t M = x_dims[0];
   const int64_t K = x_dims[1];
@@ -221,8 +250,8 @@ void FusedTransposeSplitQuantKernel(
     if (outs[i] != nullptr) {
       dev_ctx.template Alloc<phi::dtype::float8_e4m3fn>(outs[i]);
     }
-    if (scales[i] != nullptr) {
-      dev_ctx.template Alloc<float>(scales[i]);
+    if (output_scales[i] != nullptr) {
+      dev_ctx.template Alloc<float>(output_scales[i]);
     }
   }
 
@@ -245,8 +274,8 @@ void FusedTransposeSplitQuantKernel(
 
   for (size_t i = 0; i < num_experts; i++) {
     meta_ptr[num_experts * 2 + i] =
-        scales[i] != nullptr
-            ? reinterpret_cast<int64_t>(scales[i]->data<float>())
+        output_scales[i] != nullptr
+            ? reinterpret_cast<int64_t>(output_scales[i]->data<float>())
             : 0;
   }
 
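
As the pointer arithmetic in the kernel (tokens_per_expert = meta, out_ptrs = meta + num_experts, scale_ptrs = meta + num_experts * 2) and the hunk above imply, meta is a single int64_t buffer holding three arrays of length num_experts: token counts, output data pointers, and scale pointers. A minimal host-side sketch of that packing, using a hypothetical PackMeta helper that is not part of the PR:

    // Hypothetical helper, for illustration only: packs the per-expert
    // metadata in the layout the kernel expects (counts, then out pointers,
    // then scale pointers), 3 * num_experts int64_t values in total.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> PackMeta(const std::vector<int64_t>& tokens_per_expert,
                                  const std::vector<void*>& out_ptrs,
                                  const std::vector<void*>& scale_ptrs) {
      const size_t n = tokens_per_expert.size();
      std::vector<int64_t> meta(3 * n);
      for (size_t i = 0; i < n; ++i) {
        meta[i] = tokens_per_expert[i];                              // counts
        meta[n + i] = reinterpret_cast<int64_t>(out_ptrs[i]);        // out ptrs
        meta[2 * n + i] = reinterpret_cast<int64_t>(scale_ptrs[i]);  // scale ptrs
      }
      return meta;  // copied to meta_gpu before the launch
    }
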
@@ -255,23 +284,35 @@ void FusedTransposeSplitQuantKernel(
 
   auto stream = dev_ctx.stream();
 
-  dim3 grid(M / 128, (K + 127) / 128);
+  // pre-compute on CPU to reduce size_t division cost in kernel
+  const size_t k_scaled = (K + 127) / 128;
+  dim3 grid(M / 128, k_scaled);
   dim3 block(32, 16);
 
-#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE)                               \
-  FusedTransposeSplitQuantKernel<phi::dtype::float8_e4m3fn,                 \
-                                 POW_2_SCALES,                              \
-                                 VEC_SIZE>                                  \
-      <<<grid, block, 0, stream>>>(x.data<phi::dtype::bfloat16>(),          \
-                                   meta_gpu.data<int64_t>(),                \
-                                   num_experts,                             \
-                                   K);
+#define DTYPE_CASE(dtype, type) dtype == phi::DataType::type
+#define LAUNCH_KERNEL(T, POW_2_SCALES, VEC_SIZE)                            \
+  FusedTransposeSplitQuantKernel<T,                                         \
+                                 phi::dtype::float8_e4m3fn,                 \
+                                 POW_2_SCALES,                              \
+                                 VEC_SIZE><<<grid, block, 0, stream>>>(     \
+      x.data<T>(),                                                          \
+      input_scales ? input_scales.get_ptr()->data<float>() : nullptr,       \
+      meta_gpu.data<int64_t>(),                                             \
+      num_experts,                                                          \
+      K,                                                                    \
+      k_scaled);
+#define DISPATCH_DATATYPE(POW_2_SCALES, VEC_SIZE)                           \
+  if (DTYPE_CASE(x.dtype(), BFLOAT16)) {                                    \
+    LAUNCH_KERNEL(phi::bfloat16, POW_2_SCALES, VEC_SIZE);                   \
+  } else if (DTYPE_CASE(x.dtype(), FLOAT8_E4M3FN)) {                        \
+    LAUNCH_KERNEL(phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE);              \
+  }
 
 #define LAUNCH_KERNEL_PARTIAL(VEC_SIZE)                                     \
   if (pow_2_scales) {                                                       \
-    LAUNCH_KERNEL(true, VEC_SIZE);                                          \
+    DISPATCH_DATATYPE(true, VEC_SIZE);                                      \
   } else {                                                                  \
-    LAUNCH_KERNEL(false, VEC_SIZE);                                         \
+    DISPATCH_DATATYPE(false, VEC_SIZE);                                     \
   }
 
   if (K % 4 == 0) {
@@ -296,7 +337,8 @@ PD_REGISTER_KERNEL(fused_transpose_split_quant,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::bfloat16) {
+                   phi::dtype::bfloat16,
+                   phi::dtype::float8_e4m3fn) {
   kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT8_E4M3FN);
   kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
 }