@@ -68,8 +68,8 @@ template <RotaryMode rotary_mode, uint32_t vec_size, uint32_t bdx, uint32_t tile
6868__device__ __forceinline__ void compute_qk (const T* smem, uint32_t compute_stage_idx,
6969 const vec_t <float , vec_size>& q_vec,
7070 const vec_t <float , vec_size>& freq, uint32_t kv_idx_base,
71- uint32_t iter_base, uint32_t iter_bound, float sm_scale ,
72- float * s, state_t <vec_size>& st) {
71+ uint32_t iter_base, uint32_t iter_bound, float * s ,
72+ state_t <vec_size>& st) {
7373 uint32_t tx = threadIdx .x , tz = threadIdx .z ;
7474 float m_prev = st.m ;
7575#pragma unroll
@@ -86,7 +86,7 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
8686 s[j] = 0 .f ;
8787#pragma unroll
8888 for (uint32_t i = 0 ; i < vec_size; ++i) {
89- s[j] += q_vec[i] * k_vec[i] * sm_scale ;
89+ s[j] += q_vec[i] * k_vec[i];
9090 }
9191#pragma unroll
9292 for (uint32_t offset = bdx / 2 ; offset > 0 ; offset /= 2 ) {
@@ -240,6 +240,11 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
240240 // do not apply rotary embedding to q matrix
241241 q_vec.cast_load (q + info.get_qo_elem_offset (0 , qo_head_idx, tx * vec_size));
242242 }
243+ // multiply q_vec by sm_scale
244+ #pragma unroll
245+ for (uint32_t i = 0 ; i < vec_size; ++i) {
246+ q_vec[i] *= sm_scale;
247+ }
243248 block.sync ();
244249
245250 uint32_t chunk_start = kv_chunk_idx * kv_chunk_size;
@@ -286,8 +291,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
286291 block.sync ();
287292 compute_qk<rotary_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
288293 k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
289- freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, sm_scale ,
290- s, st_local);
294+ freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, s ,
295+ st_local);
291296 block.sync ();
292297 // load k
293298 for (uint32_t j = 0 ; j < tile_size_per_bdx; ++j) {
@@ -385,6 +390,10 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
385390 q_vec.cast_load (q + batch_idx * num_qo_heads * head_dim +
386391 info.get_qo_elem_offset (0 , qo_head_idx, tx * vec_size));
387392 }
393+ #pragma unroll
394+ for (uint32_t i = 0 ; i < vec_size; ++i) {
395+ q_vec[i] *= sm_scale;
396+ }
388397 block.sync ();
389398
390399 // preload k tiles and v tiles
@@ -421,7 +430,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
421430 block.sync ();
422431 compute_qk<rotary_mode, vec_size, bdx, bdy>(k_smem + (stage_idx * bdz + tz) * bdy * head_dim,
423432 stage_idx, q_vec, freq, consumer_kv_idx_base,
424- iter * bdy * bdz, seq_len, sm_scale, s, st_local);
433+ iter * bdy * bdz, seq_len, s, st_local);
425434 block.sync ();
426435 // load k
427436 cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch , SharedMemFillMode::kNoFill >(
@@ -551,6 +560,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
551560 // do not apply rotary embedding to q matrix
552561 q_vec.cast_load (q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
553562 }
563+ #pragma unroll
564+ for (uint32_t i = 0 ; i < vec_size; ++i) {
565+ q_vec[i] *= sm_scale;
566+ }
554567 block.sync ();
555568
556569 // preload k/v tiles
@@ -622,7 +635,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
622635 freq,
623636 (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset [mapped_batch_idx]) +
624637 cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
625- iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, sm_scale, s, st);
638+ iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, s, st);
626639 block.sync ();
627640
628641#pragma unroll
0 commit comments