@@ -94,7 +94,6 @@ using mlir::VectorType;
9494using mlir::func::FuncOp;
9595using mlir::func::ReturnOp;
9696
97-
9897namespace mt = ::mlir::tensor;
9998namespace mv = ::mlir::vector;
10099
@@ -532,21 +531,21 @@ std::vector<int64_t> GetBlockCounts(absl::Span<const int64_t> shape,
532531PackedTranspose::PackedTranspose (const HloFusionAnalysis& analysis,
533532 const TransposeSpec& spec,
534533 absl::Span<const int64_t > output_block_tile,
535- int64_t num_warps )
534+ int64_t num_shmem_groups )
536535 : TransposeFusionBase(analysis),
537536 spec_(spec),
538537 output_tile_(output_block_tile.begin(), output_block_tile.end()),
539538 input_tile_(Permute(output_tile_, spec_.canonical_inv_permutation)),
540539 block_counts_(GetBlockCounts(spec_.canonical_output_shape, output_tile_)),
541- num_warps_per_block_(num_warps ),
540+ num_shmem_groups_per_block_(num_shmem_groups ),
542541 tile_size_t1_(input_tile_[spec_.dim_T1_input_id()]),
543542 tile_size_a_(input_tile_[spec_.dim_A_id()]),
544543 tile_size_t2_(input_tile_[spec_.dim_T2_input_id()]),
545544 populated_shmem_cols_(tile_size_a_ * tile_size_t1_),
546545 populated_shmem_rows_(tile_size_t2_) {
547546 VLOG (5 ) << " Transpose spec: " << spec.ToString ()
548547 << " Output block tile: " << absl::StrJoin (output_block_tile, " , " )
549- << " \n Number of warps : " << num_warps << " \n " ;
548+ << " \n Number of shmem groups : " << num_shmem_groups << " \n " ;
550549 auto bits_per_element = GetBitwidth (spec_.elem_type ());
551550 vector_size_ = kBankBitwidth / bits_per_element;
552551 CHECK_GE (vector_size_, 1 );
@@ -779,25 +778,27 @@ IndexingMap PackedTranspose::GetInputIndexing(MLIRContext* ctx) const {
779778 KernelFusionInterface::kIndexingMapThreadIdxDims [0 ], ctx);
780779 auto block_id =
781780 getAffineDimExpr (KernelFusionInterface::kIndexingMapBlockIdxDims [0 ], ctx);
782- auto warp_size = kNumShmemBanks ;
783- auto lane_id = thread_id % warp_size;
784- auto warp_id = thread_id.floorDiv (warp_size);
785- std::vector<IndexingMap::Variable> dim_vars = DimVarsFromGPUGrid (
786- {num_warps_per_block_ * warp_size, 1 , 1 , Product (block_counts_), 1 , 1 });
781+ auto shmem_group_size = kNumShmemBanks ;
782+ auto lane_id = thread_id % shmem_group_size;
783+ auto shmem_group_id = thread_id.floorDiv (shmem_group_size);
784+ std::vector<IndexingMap::Variable> dim_vars =
785+ DimVarsFromGPUGrid ({num_shmem_groups_per_block_ * shmem_group_size, 1 , 1 ,
786+ Product (block_counts_), 1 , 1 });
787787
788788 // Range variables.
789789 auto loop = getAffineSymbolExpr (0 , ctx);
790790 auto vector_element_id = getAffineSymbolExpr (1 , ctx);
791791 std::vector<IndexingMap::Variable> range_vars = RangeVarsFromTensorSizes (
792- {{CeilOfRatio (tile_size_t2_, num_warps_per_block_), vector_size_}});
792+ {{CeilOfRatio (tile_size_t2_, num_shmem_groups_per_block_),
793+ vector_size_}});
793794
794795 // Block offsets.
795796 auto block_ids = DelinearizeInBoundsIndex (block_id, block_counts_);
796797 absl::c_copy (Permute (block_ids, spec_.canonical_inv_permutation ),
797798 block_ids.begin ());
798799
799800 // Shmem expressions.
800- auto shmem_row = loop * num_warps_per_block_ + warp_id ;
801+ auto shmem_row = loop * num_shmem_groups_per_block_ + shmem_group_id ;
801802 auto shmem_col = lane_id * vector_size_ + vector_element_id;
802803
803804 // Offsets within the block.
@@ -840,20 +841,21 @@ IndexingMap PackedTranspose::GetShmemWriteIndexing(
840841 // Dimensions variables.
841842 auto thread_id = getAffineDimExpr (
842843 KernelFusionInterface::kIndexingMapThreadIdxDims [0 ], ctx);
843- auto warp_size = kNumShmemBanks ;
844- auto lane_id = thread_id % warp_size;
845- auto warp_id = thread_id.floorDiv (warp_size);
846- std::vector<IndexingMap::Variable> dim_vars = DimVarsFromGPUGrid (
847- {num_warps_per_block_ * warp_size, 1 , 1 , Product (block_counts_), 1 , 1 });
844+ auto shmem_group_size = kNumShmemBanks ;
845+ auto lane_id = thread_id % shmem_group_size;
846+ auto shmem_group_id = thread_id.floorDiv (shmem_group_size);
847+ std::vector<IndexingMap::Variable> dim_vars =
848+ DimVarsFromGPUGrid ({num_shmem_groups_per_block_ * shmem_group_size, 1 , 1 ,
849+ Product (block_counts_), 1 , 1 });
848850
849851 // Range variables.
850852 auto loop = getAffineSymbolExpr (0 , ctx);
851853 auto vector_element_id = getAffineSymbolExpr (1 , ctx);
852854 std::vector<IndexingMap::Variable> range_vars = RangeVarsFromTensorSizes (
853- {CeilOfRatio (tile_size_t2_, num_warps_per_block_ ), vector_size_});
855+ {CeilOfRatio (tile_size_t2_, num_shmem_groups_per_block_ ), vector_size_});
854856
855857 // Shmem expressions.
856- auto shmem_row = loop * num_warps_per_block_ + warp_id ;
858+ auto shmem_row = loop * num_shmem_groups_per_block_ + shmem_group_id ;
857859 auto shmem_col = lane_id * vector_size_ + vector_element_id;
858860 llvm::SmallVector<std::pair<AffineExpr, Interval>> constraints{
859861 {shmem_col, Interval{0 , populated_shmem_cols_ - 1 }},
@@ -872,25 +874,27 @@ IndexingMap PackedTranspose::GetShmemReadIndexing(
872874 // Dimensions variables.
873875 auto thread_id = getAffineDimExpr (
874876 KernelFusionInterface::kIndexingMapThreadIdxDims [0 ], ctx);
875- auto warp_size = kNumShmemBanks ;
876- auto lane_id = thread_id % warp_size;
877- auto warp_id = thread_id.floorDiv (warp_size);
878- std::vector<IndexingMap::Variable> dim_vars = DimVarsFromGPUGrid (
879- {num_warps_per_block_ * warp_size, 1 , 1 , Product (block_counts_), 1 , 1 });
877+ auto shmem_group_size = kNumShmemBanks ;
878+ auto lane_id = thread_id % shmem_group_size;
879+ auto shmem_group_id = thread_id.floorDiv (shmem_group_size);
880+ std::vector<IndexingMap::Variable> dim_vars =
881+ DimVarsFromGPUGrid ({num_shmem_groups_per_block_ * shmem_group_size, 1 , 1 ,
882+ Product (block_counts_), 1 , 1 });
880883
881884 // Range variables.
882885 auto loop = getAffineSymbolExpr (0 , ctx);
883886 auto vector_horizontal = getAffineSymbolExpr (1 , ctx);
884887 auto vector_vertical = getAffineSymbolExpr (2 , ctx);
885888 std::vector<IndexingMap::Variable> range_vars = RangeVarsFromTensorSizes (
886889 {CeilOfRatio (populated_shmem_cols_,
887- (vector_size_ * num_warps_per_block_ )),
890+ (vector_size_ * num_shmem_groups_per_block_ )),
888891 vector_size_, vector_size_});
889892
890893 // Shmem expressions.
891894 auto shmem_row = lane_id * vector_size_ + vector_vertical;
892- auto shmem_col = (loop * num_warps_per_block_ + warp_id) * vector_size_ +
893- vector_horizontal;
895+ auto shmem_col =
896+ (loop * num_shmem_groups_per_block_ + shmem_group_id) * vector_size_ +
897+ vector_horizontal;
894898 llvm::SmallVector<std::pair<AffineExpr, Interval>> constraints{
895899 {shmem_col, Interval{0 , populated_shmem_cols_ - 1 }},
896900 {shmem_row, Interval{0 , populated_shmem_rows_ - 1 }}};
@@ -909,26 +913,29 @@ IndexingMap PackedTranspose::GetOutputIndexing(mlir::MLIRContext* ctx) const {
909913 KernelFusionInterface::kIndexingMapThreadIdxDims [0 ], ctx);
910914 auto block_id =
911915 getAffineDimExpr (KernelFusionInterface::kIndexingMapBlockIdxDims [0 ], ctx);
912- auto warp_size = kNumShmemBanks ;
913- auto lane_id = thread_id % warp_size;
914- auto warp_id = thread_id.floorDiv (warp_size);
915- std::vector<IndexingMap::Variable> dim_vars = DimVarsFromGPUGrid (
916- {num_warps_per_block_ * warp_size, 1 , 1 , Product (block_counts_), 1 , 1 });
916+ auto shmem_group_size = kNumShmemBanks ;
917+ auto lane_id = thread_id % shmem_group_size;
918+ auto shmem_group_id = thread_id.floorDiv (shmem_group_size);
919+ std::vector<IndexingMap::Variable> dim_vars =
920+ DimVarsFromGPUGrid ({num_shmem_groups_per_block_ * shmem_group_size, 1 , 1 ,
921+ Product (block_counts_), 1 , 1 });
917922
918923 // Range variables.
919924 auto loop = getAffineSymbolExpr (0 , ctx);
920925 auto vector_horizontal = getAffineSymbolExpr (1 , ctx);
921926 auto vector_vertical = getAffineSymbolExpr (2 , ctx);
922927 std::vector<IndexingMap::Variable> range_vars = RangeVarsFromTensorSizes (
923- {CeilOfRatio (populated_shmem_cols_, vector_size_ * num_warps_per_block_),
928+ {CeilOfRatio (populated_shmem_cols_,
929+ vector_size_ * num_shmem_groups_per_block_),
924930 vector_size_, vector_size_});
925931
926932 // Block offsets.
927933 auto block_ids = DelinearizeInBoundsIndex (block_id, block_counts_);
928934
929935 // Shmem expressions.
930- auto shmem_col = (loop * num_warps_per_block_ + warp_id) * vector_size_ +
931- vector_horizontal;
936+ auto shmem_col =
937+ (loop * num_shmem_groups_per_block_ + shmem_group_id) * vector_size_ +
938+ vector_horizontal;
932939 auto shmem_row = lane_id * vector_size_ + vector_vertical;
933940
934941 // Offsets within the block.
@@ -972,7 +979,8 @@ std::unique_ptr<EmitterBase> CreateTransposeFusion(
972979 auto packed_transpose_tile = GetPackedTransposeTileSizes (spec);
973980 if (packed_transpose_tile.ok ()) {
974981 return std::make_unique<PackedTranspose>(
975- analysis, spec, *packed_transpose_tile, /* num_warps= */ 4 );
982+ analysis, spec, *packed_transpose_tile,
983+ kNumThreadsPerBlock / kNumShmemBanks );
976984 }
977985 return std::make_unique<TransposeFusion>(analysis);
978986}
0 commit comments