Skip to content

Commit 729dcdf

Browse files
[XLA:GPU] Rename warp to shmem_group in PackedTranspose (#434)
Also calculate the shmem_group count as kNumThreadsPerBlock / kNumShmemBanks to avoid inconsistency when it is specified manually. This change is NFC (no functional change) for non-AMD GPUs. For AMD GPUs, it fixes a performance regression caused by an inconsistency between the shmem_group size, kNumThreadsPerBlock, and kNumShmemBanks, which led to a situation downstream where half of the launched threads per block were not utilized at all. Packed transpose tests are updated to verify correct thread utilization.
1 parent b3f5970 commit 729dcdf

File tree

6 files changed

+72
-49
lines changed

6 files changed

+72
-49
lines changed

xla/backends/gpu/codegen/emitters/tests/transpose/packed_transpose_bf16.hlo

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ fusion {
66
p0 = bf16[30,16,30] parameter(0)
77
ROOT transpose = bf16[30,16,30] transpose(p0), dimensions={2,1,0}
88
}
9-
// CHECK: xla_gpu.allocate_shared : tensor<64x64xbf16>
9+
// CHECK: #indexing_map{{.*}}domain:{{.*}}th_x in [0, [[N_THREADS:[0-9]+]]]
10+
// CHECK: %thread_id_x = gpu.thread_id x {xla.range = [0 : index, [[N_THREADS]] : index]}
11+
// CHECK: xla_gpu.allocate_shared : tensor<64x64xbf16>

xla/backends/gpu/codegen/emitters/tests/transpose/packed_transpose_f16.hlo

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ fusion {
66
p0 = f16[28,2,6,32] parameter(0)
77
ROOT transpose = f16[2,32,6,28] transpose(p0), dimensions={1,3,2,0}
88
}
9+
// CHECK: #indexing_map{{.*}}domain:{{.*}}th_x in [0, [[N_THREADS:[0-9]+]]]
10+
// CHECK: %thread_id_x = gpu.thread_id x {xla.range = [0 : index, [[N_THREADS]] : index]}
911
// CHECK: xla_gpu.allocate_shared : tensor<64x64xf16>

xla/backends/gpu/codegen/emitters/tests/transpose/packed_transpose_s4.hlo

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ fusion {
77
ROOT %transpose= s4[128, 32, 8, 256] transpose(%param_0),
88
dimensions={0,3,2,1}
99
}
10+
// CHECK: #indexing_map{{.*}}domain:{{.*}}th_x in [0, [[N_THREADS:[0-9]+]]]
11+
// CHECK: %thread_id_x = gpu.thread_id x {xla.range = [0 : index, [[N_THREADS]] : index]}
1012
// CHECK: xla_gpu.allocate_shared : tensor<256x256xi4>

xla/backends/gpu/codegen/emitters/tests/transpose/packed_transpose_s8.hlo

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ fusion {
66
p0 = s8[8,64,68] parameter(0)
77
ROOT transpose = s8[8,68,64] transpose(p0), dimensions={0, 2, 1}
88
}
9-
// CHECK: xla_gpu.allocate_shared : tensor<128x128xi8>
9+
// CHECK: #indexing_map{{.*}}domain:{{.*}}th_x in [0, [[N_THREADS:[0-9]+]]]
10+
// CHECK: %thread_id_x = gpu.thread_id x {xla.range = [0 : index, [[N_THREADS]] : index]}
11+
// CHECK: xla_gpu.allocate_shared : tensor<128x128xi8>

xla/backends/gpu/codegen/emitters/transpose.cc

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ using mlir::VectorType;
9494
using mlir::func::FuncOp;
9595
using mlir::func::ReturnOp;
9696

97-
9897
namespace mt = ::mlir::tensor;
9998
namespace mv = ::mlir::vector;
10099

@@ -532,21 +531,21 @@ std::vector<int64_t> GetBlockCounts(absl::Span<const int64_t> shape,
532531
PackedTranspose::PackedTranspose(const HloFusionAnalysis& analysis,
533532
const TransposeSpec& spec,
534533
absl::Span<const int64_t> output_block_tile,
535-
int64_t num_warps)
534+
int64_t num_shmem_groups)
536535
: TransposeFusionBase(analysis),
537536
spec_(spec),
538537
output_tile_(output_block_tile.begin(), output_block_tile.end()),
539538
input_tile_(Permute(output_tile_, spec_.canonical_inv_permutation)),
540539
block_counts_(GetBlockCounts(spec_.canonical_output_shape, output_tile_)),
541-
num_warps_per_block_(num_warps),
540+
num_shmem_groups_per_block_(num_shmem_groups),
542541
tile_size_t1_(input_tile_[spec_.dim_T1_input_id()]),
543542
tile_size_a_(input_tile_[spec_.dim_A_id()]),
544543
tile_size_t2_(input_tile_[spec_.dim_T2_input_id()]),
545544
populated_shmem_cols_(tile_size_a_ * tile_size_t1_),
546545
populated_shmem_rows_(tile_size_t2_) {
547546
VLOG(5) << "Transpose spec: " << spec.ToString()
548547
<< "Output block tile: " << absl::StrJoin(output_block_tile, ", ")
549-
<< "\nNumber of warps: " << num_warps << "\n";
548+
<< "\nNumber of shmem groups: " << num_shmem_groups << "\n";
550549
auto bits_per_element = GetBitwidth(spec_.elem_type());
551550
vector_size_ = kBankBitwidth / bits_per_element;
552551
CHECK_GE(vector_size_, 1);
@@ -779,25 +778,27 @@ IndexingMap PackedTranspose::GetInputIndexing(MLIRContext* ctx) const {
779778
KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx);
780779
auto block_id =
781780
getAffineDimExpr(KernelFusionInterface::kIndexingMapBlockIdxDims[0], ctx);
782-
auto warp_size = kNumShmemBanks;
783-
auto lane_id = thread_id % warp_size;
784-
auto warp_id = thread_id.floorDiv(warp_size);
785-
std::vector<IndexingMap::Variable> dim_vars = DimVarsFromGPUGrid(
786-
{num_warps_per_block_ * warp_size, 1, 1, Product(block_counts_), 1, 1});
781+
auto shmem_group_size = kNumShmemBanks;
782+
auto lane_id = thread_id % shmem_group_size;
783+
auto shmem_group_id = thread_id.floorDiv(shmem_group_size);
784+
std::vector<IndexingMap::Variable> dim_vars =
785+
DimVarsFromGPUGrid({num_shmem_groups_per_block_ * shmem_group_size, 1, 1,
786+
Product(block_counts_), 1, 1});
787787

788788
// Range variables.
789789
auto loop = getAffineSymbolExpr(0, ctx);
790790
auto vector_element_id = getAffineSymbolExpr(1, ctx);
791791
std::vector<IndexingMap::Variable> range_vars = RangeVarsFromTensorSizes(
792-
{{CeilOfRatio(tile_size_t2_, num_warps_per_block_), vector_size_}});
792+
{{CeilOfRatio(tile_size_t2_, num_shmem_groups_per_block_),
793+
vector_size_}});
793794

794795
// Block offsets.
795796
auto block_ids = DelinearizeInBoundsIndex(block_id, block_counts_);
796797
absl::c_copy(Permute(block_ids, spec_.canonical_inv_permutation),
797798
block_ids.begin());
798799

799800
// Shmem expressions.
800-
auto shmem_row = loop * num_warps_per_block_ + warp_id;
801+
auto shmem_row = loop * num_shmem_groups_per_block_ + shmem_group_id;
801802
auto shmem_col = lane_id * vector_size_ + vector_element_id;
802803

803804
// Offsets within the block.
@@ -840,20 +841,21 @@ IndexingMap PackedTranspose::GetShmemWriteIndexing(
840841
// Dimensions variables.
841842
auto thread_id = getAffineDimExpr(
842843
KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx);
843-
auto warp_size = kNumShmemBanks;
844-
auto lane_id = thread_id % warp_size;
845-
auto warp_id = thread_id.floorDiv(warp_size);
846-
std::vector<IndexingMap::Variable> dim_vars = DimVarsFromGPUGrid(
847-
{num_warps_per_block_ * warp_size, 1, 1, Product(block_counts_), 1, 1});
844+
auto shmem_group_size = kNumShmemBanks;
845+
auto lane_id = thread_id % shmem_group_size;
846+
auto shmem_group_id = thread_id.floorDiv(shmem_group_size);
847+
std::vector<IndexingMap::Variable> dim_vars =
848+
DimVarsFromGPUGrid({num_shmem_groups_per_block_ * shmem_group_size, 1, 1,
849+
Product(block_counts_), 1, 1});
848850

849851
// Range variables.
850852
auto loop = getAffineSymbolExpr(0, ctx);
851853
auto vector_element_id = getAffineSymbolExpr(1, ctx);
852854
std::vector<IndexingMap::Variable> range_vars = RangeVarsFromTensorSizes(
853-
{CeilOfRatio(tile_size_t2_, num_warps_per_block_), vector_size_});
855+
{CeilOfRatio(tile_size_t2_, num_shmem_groups_per_block_), vector_size_});
854856

855857
// Shmem expressions.
856-
auto shmem_row = loop * num_warps_per_block_ + warp_id;
858+
auto shmem_row = loop * num_shmem_groups_per_block_ + shmem_group_id;
857859
auto shmem_col = lane_id * vector_size_ + vector_element_id;
858860
llvm::SmallVector<std::pair<AffineExpr, Interval>> constraints{
859861
{shmem_col, Interval{0, populated_shmem_cols_ - 1}},
@@ -872,25 +874,27 @@ IndexingMap PackedTranspose::GetShmemReadIndexing(
872874
// Dimensions variables.
873875
auto thread_id = getAffineDimExpr(
874876
KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx);
875-
auto warp_size = kNumShmemBanks;
876-
auto lane_id = thread_id % warp_size;
877-
auto warp_id = thread_id.floorDiv(warp_size);
878-
std::vector<IndexingMap::Variable> dim_vars = DimVarsFromGPUGrid(
879-
{num_warps_per_block_ * warp_size, 1, 1, Product(block_counts_), 1, 1});
877+
auto shmem_group_size = kNumShmemBanks;
878+
auto lane_id = thread_id % shmem_group_size;
879+
auto shmem_group_id = thread_id.floorDiv(shmem_group_size);
880+
std::vector<IndexingMap::Variable> dim_vars =
881+
DimVarsFromGPUGrid({num_shmem_groups_per_block_ * shmem_group_size, 1, 1,
882+
Product(block_counts_), 1, 1});
880883

881884
// Range variables.
882885
auto loop = getAffineSymbolExpr(0, ctx);
883886
auto vector_horizontal = getAffineSymbolExpr(1, ctx);
884887
auto vector_vertical = getAffineSymbolExpr(2, ctx);
885888
std::vector<IndexingMap::Variable> range_vars = RangeVarsFromTensorSizes(
886889
{CeilOfRatio(populated_shmem_cols_,
887-
(vector_size_ * num_warps_per_block_)),
890+
(vector_size_ * num_shmem_groups_per_block_)),
888891
vector_size_, vector_size_});
889892

890893
// Shmem expressions.
891894
auto shmem_row = lane_id * vector_size_ + vector_vertical;
892-
auto shmem_col = (loop * num_warps_per_block_ + warp_id) * vector_size_ +
893-
vector_horizontal;
895+
auto shmem_col =
896+
(loop * num_shmem_groups_per_block_ + shmem_group_id) * vector_size_ +
897+
vector_horizontal;
894898
llvm::SmallVector<std::pair<AffineExpr, Interval>> constraints{
895899
{shmem_col, Interval{0, populated_shmem_cols_ - 1}},
896900
{shmem_row, Interval{0, populated_shmem_rows_ - 1}}};
@@ -909,26 +913,29 @@ IndexingMap PackedTranspose::GetOutputIndexing(mlir::MLIRContext* ctx) const {
909913
KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx);
910914
auto block_id =
911915
getAffineDimExpr(KernelFusionInterface::kIndexingMapBlockIdxDims[0], ctx);
912-
auto warp_size = kNumShmemBanks;
913-
auto lane_id = thread_id % warp_size;
914-
auto warp_id = thread_id.floorDiv(warp_size);
915-
std::vector<IndexingMap::Variable> dim_vars = DimVarsFromGPUGrid(
916-
{num_warps_per_block_ * warp_size, 1, 1, Product(block_counts_), 1, 1});
916+
auto shmem_group_size = kNumShmemBanks;
917+
auto lane_id = thread_id % shmem_group_size;
918+
auto shmem_group_id = thread_id.floorDiv(shmem_group_size);
919+
std::vector<IndexingMap::Variable> dim_vars =
920+
DimVarsFromGPUGrid({num_shmem_groups_per_block_ * shmem_group_size, 1, 1,
921+
Product(block_counts_), 1, 1});
917922

918923
// Range variables.
919924
auto loop = getAffineSymbolExpr(0, ctx);
920925
auto vector_horizontal = getAffineSymbolExpr(1, ctx);
921926
auto vector_vertical = getAffineSymbolExpr(2, ctx);
922927
std::vector<IndexingMap::Variable> range_vars = RangeVarsFromTensorSizes(
923-
{CeilOfRatio(populated_shmem_cols_, vector_size_ * num_warps_per_block_),
928+
{CeilOfRatio(populated_shmem_cols_,
929+
vector_size_ * num_shmem_groups_per_block_),
924930
vector_size_, vector_size_});
925931

926932
// Block offsets.
927933
auto block_ids = DelinearizeInBoundsIndex(block_id, block_counts_);
928934

929935
// Shmem expressions.
930-
auto shmem_col = (loop * num_warps_per_block_ + warp_id) * vector_size_ +
931-
vector_horizontal;
936+
auto shmem_col =
937+
(loop * num_shmem_groups_per_block_ + shmem_group_id) * vector_size_ +
938+
vector_horizontal;
932939
auto shmem_row = lane_id * vector_size_ + vector_vertical;
933940

934941
// Offsets within the block.
@@ -972,7 +979,8 @@ std::unique_ptr<EmitterBase> CreateTransposeFusion(
972979
auto packed_transpose_tile = GetPackedTransposeTileSizes(spec);
973980
if (packed_transpose_tile.ok()) {
974981
return std::make_unique<PackedTranspose>(
975-
analysis, spec, *packed_transpose_tile, /* num_warps= */ 4);
982+
analysis, spec, *packed_transpose_tile,
983+
kNumThreadsPerBlock / kNumShmemBanks);
976984
}
977985
return std::make_unique<TransposeFusion>(analysis);
978986
}

xla/backends/gpu/codegen/emitters/transpose.h

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,19 +196,24 @@ class TransposeFusion : public TransposeFusionBase {
196196
// slice of shared memory.
197197
//
198198
// 5. Every GPU block gets a single 64 x 10 x 6 x bf16 tile.
199-
// The tile is read by `num_warps_per_block` warps.
200-
// Let's assume that there are 4 warps per block. In this case, on every
201-
// iteration each warp will read 10 x 6 x bf16 elements, i.e. every thread
202-
// (30 out of 32) performs a vector load of 2 x bf16 and stores it to the
203-
// shared memory. In total, there will be 16 iterations performed by each
204-
// block.
199+
// The tile is read by `num_shmem_groups_per_block` shmem groups.
200+
// Let's assume that there are 4 shmem groups per block. In this case, on
201+
// every iteration each shmem group will read 10 x 6 x bf16 elements, i.e.
202+
// every thread (30 out of 32) performs a vector load of 2 x bf16 and stores
203+
// it to the shared memory. In total, there will be 16 iterations performed
204+
// by each block.
205+
//
206+
// Note: When the hardware warp size equals kNumShmemBanks (32), then
207+
// num_shmem_groups_per_block equals the number of warps per block. This is
208+
// the case for NVIDIA GPUs, but not always for AMD GPUs where warp size
209+
// can differ (64).
205210
//
206211
// The following code snippet shows how the data is read from the input
207212
// tensor into the shared memory:
208213
//
209-
// for I = 0 to CEIL(shmem_rows, num_warps_per_block):
214+
// for I = 0 to CEIL(shmem_rows, num_shmem_groups_per_block):
210215
// for J = 0 to VECTOR_SIZE:
211-
// ROW = WARP_ID + NUM_WARPS * I
216+
// ROW = SHMEM_GROUP_ID + NUM_SHMEM_GROUPS * I
212217
// COL = LANE_ID * VECTOR_SIZE + J
213218
// SHMEM[ROW, COL] = INPUT[ROW, COL / 10, COL % 10]
214219
//
@@ -217,7 +222,7 @@ class TransposeFusion : public TransposeFusionBase {
217222
// 6. Each thread reads a VECTOR_SIZE x VECTOR_SIZE x bf16 tile from the shared
218223
// memory and performs the write of each of the columns of the tile.
219224
//
220-
// for I = 0 to CEIL(shmem_cols, VECTOR_SIZE * num_warps_per_block):
225+
// for I = 0 to CEIL(shmem_cols, VECTOR_SIZE * num_shmem_groups_per_block):
221226
// VECTOR_2D = arith.constant dense<0>
222227
// : vector<VECTOR_SIZE x VECTOR_SIZE x bf16>
223228
// for J = 0 to VECTOR_SIZE:
@@ -231,7 +236,7 @@ class PackedTranspose : public TransposeFusionBase {
231236
explicit PackedTranspose(const HloFusionAnalysis& analysis,
232237
const TransposeSpec& spec,
233238
absl::Span<const int64_t> output_block_tile,
234-
int64_t num_warps);
239+
int64_t num_shmem_groups);
235240

236241
LaunchDimensions launch_dimensions() const override;
237242

@@ -279,8 +284,10 @@ class PackedTranspose : public TransposeFusionBase {
279284
// Vector size in elements.
280285
int64_t vector_size_;
281286

282-
// Number of warps per block.
283-
int64_t num_warps_per_block_;
287+
// Number of shmem groups per block. Each shmem group consists of 32 threads
288+
// (kNumShmemBanks), chosen to match the number of shared memory banks for
289+
// optimal memory access patterns. This is independent of hardware warp size.
290+
int64_t num_shmem_groups_per_block_;
284291

285292
// Tile sizes for the canonical dimensions
286293
// [T2, A, T1, 1] -> [T1, A, T2, 1].

0 commit comments

Comments
 (0)