diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu index 021ca31ec5..e4afcb3008 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu @@ -557,13 +557,13 @@ bool verify(const Options &options) { auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(), cute::make_layout( cute::make_shape(blockscale_m, blockscale_k, options.l), - cute::make_stride(blockscale_k, 1, blockscale_m * blockscale_k) + cute::make_stride(1, blockscale_m, blockscale_m * blockscale_k) ) ); auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(), cute::make_layout( cute::make_shape(blockscale_n, blockscale_k, options.l), - cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k) + cute::make_stride(1, blockscale_n, blockscale_n * blockscale_k) ) ); diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu index ccd0941d06..0394576490 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ -396,14 +396,17 @@ template void initialize(const Options &options) { using TileShape = typename GroupScaleConfig::TileShape; - const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile; - const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile; + const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM; + const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN; + + assert(options.m % ScaleGranularityM == 0); + assert(options.n % ScaleGranularityN == 0); // Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{}))); - auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access. - auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of A to prevent illegal memory access. 
+ auto groupscale_m = cute::get<0>(gemm_problem_shape) / ScaleGranularityM; + auto groupscale_n = cute::get<1>(gemm_problem_shape) / ScaleGranularityN; auto blockscale_k = cute::get<2>(blockscale_shape); stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); @@ -575,6 +578,8 @@ bool verify(const Options &options, const int ScaleMsPerTile // // Compute reference output // + const int ScaleGranularityM = get<0>(TileShape_{}) / ScaleMsPerTile; + const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile; // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); @@ -582,6 +587,8 @@ bool verify(const Options &options, const int ScaleMsPerTile auto blockscale_m = cute::get<0>(blockscale_shape); auto blockscale_n = cute::get<1>(blockscale_shape); auto blockscale_k = cute::get<2>(blockscale_shape); + auto groupscale_m = get<0>(gemm_problem_shape) / ScaleGranularityM; + auto groupscale_n = get<1>(gemm_problem_shape) / ScaleGranularityN; // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(tensor_A.host_data(), @@ -617,14 +624,14 @@ bool verify(const Options &options, const int ScaleMsPerTile auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(), cute::make_layout( - cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l), - cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile) + cute::make_shape(groupscale_m, blockscale_k, options.l), + cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k) ) ); auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(), cute::make_layout( - cute::make_shape(blockscale_n, ScaleNsPerTile, blockscale_k, options.l), - cute::make_stride(blockscale_k * ScaleNsPerTile, 1, ScaleNsPerTile, blockscale_n * blockscale_k * ScaleNsPerTile) + cute::make_shape(groupscale_n, blockscale_k, options.l), + cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k) ) ); @@ -708,6 +715,31 @@ int run(Options &options) const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile; const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile; + bool skip = false; + + if (options.m % ScaleGranularityM != 0) { + std::cout << "Skipping (m size: " << options.m << " is not divisible by ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl; + skip = true; + } + + if (options.n % ScaleGranularityN != 0) { + std::cout << "Skipping (n size: " << options.n << " is not divisible by ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl; + skip = true; + } + + if (options.k % size<2>(TileShape{}) != 0) { + std::cout << "Skipping (k size: " << options.k << " is not divisible by TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl; + skip = true; + } + + if (!skip) std::cout << "Running: " << std::endl; + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; + std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; + std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; + + if (skip) return -1; + initialize(options); // Instantiate CUTLASS kernel depending
on templates @@ -768,10 +800,6 @@ int run(Options &options) raster = "Along M"; } - std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; - std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; - std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; - std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; std::cout << " GFLOPS: " << result.gflops << std::endl; diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h index e9809f6b2e..6bb593bda9 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h @@ -217,15 +217,19 @@ void gett_mainloop( } } - int64_t block_m = m / kBlockM; - int64_t block_n = n / kBlockN; - cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l); - cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, _, l); - - const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape()); - const int ScaleGranularityN = cute::size<1>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleB.shape()); - assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape())); - assert(cute::size<1>(typename MainloopParams::TileShape{}) == ScaleGranularityN * cute::size<1>(mainloop_params.ScaleB.shape())); + const int M = cute::size<0>(mainloop_params.A.layout()); + const int N = cute::size<0>(mainloop_params.B.layout()); + const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA); + const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB); + assert(ScaleGranularityM && M % ScaleGranularityM == 0 + && "ScaleGranularityM must divide M"); + assert(ScaleGranularityN && N % ScaleGranularityN == 0 + && "ScaleGranularityN must divide N"); + + cute::Tensor blockscale_A = domain_offset( + make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l)); + cute::Tensor blockscale_B = domain_offset( + make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l)); // Compute on this k-block for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { @@ -257,9 +261,12 @@ void gett_mainloop( } } + int m_size = std::min(static_cast(kBlockM), cute::size<0>(mainloop_params.A.layout()) - m); + int n_size = std::min(static_cast(kBlockN), cute::size<0>(mainloop_params.B.layout()) - n); + // do compute - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < m_size; ++m_b) { + for (int n_b = 0; n_b < n_size; ++n_b) { acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]); } } @@ -269,9 +276,9 @@ void gett_mainloop( // (b) Zero-out partial temporary (acc_temp), // (c) Update 
permanent (accu) if ((k+1) % kBlockK == 0) { - for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int m_b = 0; m_b < m_size; ++m_b) { auto scale_a_m_b = scale_a[m_b / ScaleGranularityM]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int n_b = 0; n_b < n_size; ++n_b) { auto scale_b_n_b = scale_b[n_b / ScaleGranularityN]; ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b; acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index fa00c27e7e..e3e4060a93 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -118,8 +118,8 @@ struct CollectiveMma< using PipelineState = cutlass::PipelineState; using PipelineParams = typename MainloopPipeline::Params; - // Two threads per CTA are producers (1 for operand tile and 1 for scales) - static constexpr int NumProducerThreadEvents = 2; + // 33 threads per CTA are producers (1 for the operand tiles via `tma`, and 32 for the scales via `cp.async`) + static constexpr int NumProducerThreadEvents = 33; static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape{}) : ScaleGranularityN_; @@ -150,10 +150,9 @@ struct CollectiveMma< cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // Block scaling gmem-to-smem copy atom - using BlockScaleCopyTypeA = cute::uint_byte_t(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>; - using BlockScaleCopyTypeB = cute::uint_byte_t(sizeof(ElementBlockScale)) * ScaleNsPerTile, 16)>; - using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; - using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; + // We can have partial tiles in M or N, so don't vectorize those loads + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; // Block scaling smem layout using SmemLayoutScaleA = Layout, Int>>; @@ -217,7 +216,6 @@ struct CollectiveMma< uint32_t tma_transaction_bytes = TmaTransactionBytes; uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; - uint32_t mma_promotion_interval = 4; // Block scaling factors for A and B ElementBlockScale const* ptr_scale_A; ElementBlockScale const* ptr_scale_B; @@ -263,7 +261,6 @@ struct CollectiveMma< transaction_bytes, transaction_bytes_mk, transaction_bytes_nk, - args.mma_promotion_interval, args.ptr_scale_A, args.ptr_scale_B }; @@ -283,11 +280,15 @@ struct CollectiveMma< implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions.
*/ constexpr int pipe_k = size<2>(TileShape{}) / tile_size<2>(TiledMma{}); implementable = implementable && (args.mma_promotion_interval % 4 == 0) && (args.mma_promotion_interval == ScalePromotionInterval); implementable = implementable && (pipe_k % 4 == 0) && (pipe_k <= args.mma_promotion_interval); + // We expect full tiles in K + implementable = implementable && (K % size<2>(TileShape{}) == 0); + if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); } @@ -331,18 +332,18 @@ struct CollectiveMma< Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + auto tK = get<3>(gA_mkl.shape()); + // Make the tiled views of scale tensors - auto scaleA_shape = make_shape(shape<2>(gA_mkl), Int{}, shape<3>(gA_mkl), shape<4>(gA_mkl)); // (m,ScaleMsPerTile,k,l) - auto scaleB_shape = make_shape(shape<2>(gB_nkl), Int{}, shape<3>(gB_nkl), shape<4>(gB_nkl)); // (n,ScaleNsPerTile,k,l) - auto scale_dA = compact_order(scaleA_shape, Step<_2,_0,_1,_3>{}); - auto scale_dB = compact_order(scaleB_shape, Step<_2,_0,_1,_3>{}); - auto scaleA_layout = make_layout(scaleA_shape, scale_dA); - auto scaleB_layout = make_layout(scaleB_shape, scale_dB); + auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) + auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); + auto scaleB_shape = make_shape(N / ScaleGranularityN, tK, L); // (scale_n,k,l) + auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_0, _1, _2>{}); // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. 
- Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,ScaleMsPerTile,k,l) - Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,ScaleNsPerTile,k,l) + Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) + Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (scale_n,k,l) return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); } @@ -367,103 +368,134 @@ struct CollectiveMma< TensorStorage& shared_tensors) { int lane_predicate = cute::elect_one_sync(); // Blockscaling: Tma loads for load_input and CpAsync for load_scale - if (lane_predicate) { - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k) - - // - // Prepare the TMA loads for A and B - // - - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - - // Partition the inputs based on the current block coordinates. 
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - - // Block scaling: load_scale has scaling tensors in global memory which are not tiled - Tensor mScaleA_mkl = get<2>(load_inputs); - Tensor mScaleB_nkl = get<3>(load_inputs); + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k) - Tensor gScaleA = mScaleA_mkl(m_coord,_,_,l_coord); // (1,ScaleMsPerTile,k,1) - Tensor gScaleB = mScaleB_nkl(n_coord,_,_,l_coord); // (1,ScaleNsPerTile,k,1) - - TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout>{}, Layout>>{}); // (1,ScaleMsPerTile,1) - TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout>{}, Layout>>{}); // (1,ScaleNsPerTile,1) - ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); - ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); - - Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); - Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); - - Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); - Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); + // + // Prepare the TMA loads for A and B + // - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mScaleA_mkl = get<2>(load_inputs); + Tensor mScaleB_nkl = get<3>(load_inputs); + auto scales_m = get<0>(mScaleA_mkl.shape()); + auto scales_n = get<0>(mScaleB_nkl.shape()); + + Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); + Tensor cScaleB_nkl = make_identity_tensor(mScaleB_nkl.shape()); + + Tensor gScaleA = local_tile( + mScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor cScaleA = local_tile( + cScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); + Tensor gScaleB = local_tile( + mScaleB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); // (ScaleNsPerTile,k,1) + Tensor cScaleB = local_tile( + cScaleB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); + + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, Layout>{}); + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, + Layout>{}, Layout>{}); + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); + + Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); + Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA); + Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); + + Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); + Tensor tBcB_ScaleB = thr_scale_copy_b.partition_S(cScaleB); + Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); + Tensor tBpB_ScaleB = make_tensor(shape(tBsB_ScaleB(_,_,0))); + + #pragma unroll + for (int i = 0; i < size(tApA_ScaleA); ++i) { + tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < + std::min(scales_m, (m_coord + 1) * ScaleMsPerTile); + } - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + #pragma unroll + for (int i = 0; i < size(tBpB_ScaleB); ++i) { + tBpB_ScaleB(i) = get<0>(tBcB_ScaleB(i)) < + std::min(scales_n, (n_coord + 1) * ScaleNsPerTile); + } - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; - // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors - // Maps the tile -> block, value - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); - } + // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); } + } - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for 
(int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); - } + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); } + } - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); - // - // Copy gmem to smem for *k_tile_iter - // - int write_stage = smem_pipe_write.index(); - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + // + // Copy gmem to smem for *k_tile_iter + // + int write_stage = smem_pipe_write.index(); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - // Copy operands A and B from global memory to shared memory - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + // Copy operands A and B from global memory to shared memory + if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); - // Copy scale tensors from global memory to shared memory - copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); - copy(scale_copy_b, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage)); - pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + // Copy scale tensors from global memory to shared memory + copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); + copy_if(scale_copy_b, tBpB_ScaleB, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); - ++k_tile_iter; + ++k_tile_iter; - // Advance smem_pipe_write - ++smem_pipe_write; - } + // Advance smem_pipe_write + ++smem_pipe_write; } }
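Notes (not part of the diff above).

The groupwise scale factors for A and B now live in plain rank-3 tensors of shape (M / ScaleGranularityM, K_tiles, L) and (N / ScaleGranularityN, K_tiles, L), with the group mode fastest, and the host reference in `verify()` builds exactly that layout via `make_stride(1, groupscale_m, groupscale_m * blockscale_k)`. The sketch below only illustrates that layout and is not code from the example: the sizes (`M`, `K`, `L`, `TileK`, `ScaleGranularityM`) are made up, and it assumes a C++17 build with the CUTLASS/CuTe headers on the include path.

```cpp
// Hedged sketch: print the m-major groupwise scale layout for A and the
// linear index of the scale that multiplies A(m, k). Illustrative sizes only.
#include <cstdio>
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  int M = 512, K = 1024, L = 1;       // assumed problem sizes
  int ScaleGranularityM = 64;         // one scale per 64 rows of A
  int TileK = 128;                    // CTA tile size along K (one scale per K-tile)

  int scale_m = M / ScaleGranularityM;
  int tK      = K / TileK;

  // Same construction as the new load path: shape (scale_m, k, l), mode-0 fastest,
  // i.e. stride (1, scale_m, scale_m * tK).
  auto scaleA_layout = make_ordered_layout(make_shape(scale_m, tK, L), Step<_0, _1, _2>{});
  print(scaleA_layout); print("\n");

  // The scale applied to A(m, k) in the reference is ScaleA(m / ScaleGranularityM, k / TileK, l).
  int m = 130, k = 300, l = 0;
  printf("scale index for A(%d,%d) = %d\n", m, k,
         int(scaleA_layout(m / ScaleGranularityM, k / TileK, l)));
  return 0;
}
```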
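In the same spirit, here is a scalar model (invented sizes and values, not the example's reference code) of how the updated host reference in gemm_with_groupwise_scaling.h applies the scales: partial products are accumulated into `acc_temp` for one K-block, and at the end of that block the accumulator is folded into `acc` after multiplying by `scale_a[m_b / ScaleGranularityM] * scale_b[n_b / ScaleGranularityN]` for that block.

```cpp
// Hedged sketch of groupwise-scaled GEMM accumulation: accumulate one K-block,
// then apply that block's per-group scales for A and B. All sizes/values invented.
#include <cstdio>
#include <vector>

int main() {
  int M = 8, N = 8, K = 8;
  int gM = 4, gN = 4, blockK = 4;   // scale granularities in M/N and the K-block size
  std::vector<float> A(M * K, 1.0f), B(N * K, 1.0f), C(M * N, 0.0f);
  std::vector<float> scale_A((M / gM) * (K / blockK), 0.5f);  // one scale per (m-group, k-block)
  std::vector<float> scale_B((N / gN) * (K / blockK), 2.0f);  // one scale per (n-group, k-block)

  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f, acc_temp = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc_temp += A[m * K + k] * B[n * K + k];
        if ((k + 1) % blockK == 0) {            // end of a K-block:
          int kb = k / blockK;                  // fold in that block's scales
          float sa = scale_A[(m / gM) * (K / blockK) + kb];
          float sb = scale_B[(n / gN) * (K / blockK) + kb];
          acc += acc_temp * sa * sb;
          acc_temp = 0.0f;
        }
      }
      C[m * N + n] = acc;
    }
  }
  // Every product contributes 1 * 1 * 0.5 * 2 = 1, so C[0] == K == 8.
  printf("C[0] = %f\n", C[0]);
  return 0;
}
```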