Skip to content

Commit 56bd5d2

Browse files
authored
fix flashmaskv2 maxmin buffer padding (#74881)
1 parent 472f376 commit 56bd5d2

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1190,7 +1190,7 @@ void FlashMaskV2GradBaseKernel(
11901190
// TODO(umiswing): refine this block constraint (kBlockN % 32), since some
11911191
// of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] =
11921192
// (flashmask_maxmin_shape[2] + 31) / 32 * 8;
1193-
flashmask_maxmin_shape[2] = (flashmask_maxmin_shape[2] + 31) / 32;
1193+
flashmask_maxmin_shape[2] = ((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;
11941194
flashmask_maxmin_shape[3] = 8;
11951195

11961196
flashmask_maxmin.set_type(phi::DataType::INT32);

paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1965,7 +1965,7 @@ void FlashMaskV2BaseKernel(
19651965
// TODO(umiswing): refine this block constraint (kBlockN % 32), since some
19661966
// of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] =
19671967
// (flashmask_maxmin_shape[2] + 31) / 32 * 8;
1968-
flashmask_maxmin_shape[2] = ((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4;
1968+
flashmask_maxmin_shape[2] = ((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;
19691969
flashmask_maxmin_shape[3] = 8;
19701970

19711971
flashmask_maxmin.set_type(phi::DataType::INT32);

0 commit comments

Comments
 (0)