GMMA::ScaleOut::Zero Not Equivalent to clear()?
#2284
I'm trying to develop a GEMM example that uses TMA and WGMMA on NVIDIA Hopper GPUs. In the CUTLASS examples, the accumulator is initialized using:

```cpp
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
```

This sets the MMA's scale-out to zero, so the accumulator's existing contents are ignored rather than added to. However, in my implementation, this approach doesn't work as expected. For example, if …

Here's a simplified snippet of the code:

```cpp
// ...
// Allocate the accumulators
Tensor accum = partition_fragment_C(tiled_mma, take<0, 2>(TileShapeMNK{}));  // (MMA,MMA_M,MMA_N)
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;  // Init mma accumulate_
CUTE_NO_UNROLL
for (int k_tile_idx = 0; k_tile_idx < k_tile_count; ++k_tile_idx) {
  // ... copy A, B and sync
  warpgroup_fence_operand(accum);
  warpgroup_arrive();
  // (V,M,K) x (V,N,K) => (V,M,N)
  cute::gemm(tiled_mma, tCrA, tCrB, accum);
  tiled_mma.accumulate_ = GMMA::ScaleOut::One;
  warpgroup_commit_batch();  // wgmma.commit_group
  warpgroup_wait<0>();       // wgmma.wait_group: wait for all MMAs in a k-tile to complete
  warpgroup_fence_operand(accum);
}
```

Conversely, explicitly invoking `clear(accum)` before the loop, without touching `accumulate_`, produces the expected results:

```cpp
// ...
// Allocate the accumulators
Tensor accum = partition_fragment_C(tiled_mma, take<0, 2>(TileShapeMNK{}));  // (MMA,MMA_M,MMA_N)
clear(accum);
CUTE_NO_UNROLL
for (int k_tile_idx = 0; k_tile_idx < k_tile_count; ++k_tile_idx) {
  // ... copy A, B and sync
  warpgroup_fence_operand(accum);
  warpgroup_arrive();
  // (V,M,K) x (V,N,K) => (V,M,N)
  cute::gemm(tiled_mma, tCrA, tCrB, accum);
  warpgroup_commit_batch();  // wgmma.commit_group
  warpgroup_wait<0>();       // wgmma.wait_group: wait for all MMAs in a k-tile to complete
  warpgroup_fence_operand(accum);
}
```

I would appreciate any insights into why `GMMA::ScaleOut::Zero` is not equivalent to `clear()` here. For reference, I'm utilizing the …
Replies: 1 comment 2 replies
You're setting it to `Zero` for the entire first k-tile, which covers more than one MMA. It has to be `Zero` for the first k-block only. See the CUTLASS mainloops, which unroll the K iteration of `cute::gemm` so they can set the scale value to `One` after the first MMA.