Skip to content

Commit b03cd94

Browse files
zoranjovanovic-nspemeliya
authored and committed
Rocm jaxlib v0.5.0 warpsize global (#177)
* Cherry-picked passing the warp size to Triton calls, and globally enabled warpsize=64. * Fix. --------- Co-authored-by: Pavel Emeliyanenko <[email protected]> (cherry picked from commit f013645)
1 parent 9b74aba commit b03cd94

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

xla/backends/gpu/codegen/emitters/reduction.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class ReductionFusion : public EmitterBase {
121121
return IndexingMap::GetUndefined();
122122
}
123123

124-
int64_t WarpSize() const {
124+
virtual int64_t WarpSize() const {
125125
return ::xla::gpu::WarpSize(analysis_.device_info());
126126
}
127127

@@ -198,6 +198,11 @@ class ColumnReductionFusion : public ReductionFusion {
198198
public:
199199
explicit ColumnReductionFusion(const HloFusionAnalysis& analysis);
200200

201+
int64_t WarpSize() const override {
202+
// PAE HACK HACK
203+
return 32;
204+
}
205+
201206
protected:
202207
llvm::SmallVector<mlir::Value> EmitReduction(
203208
int group_id, EmitterState& state) const override;
@@ -216,6 +221,11 @@ class SmallColumnReductionFusion : public ReductionFusion {
216221
public:
217222
explicit SmallColumnReductionFusion(const HloFusionAnalysis& analysis);
218223

224+
int64_t WarpSize() const override {
225+
// PAE HACK HACK
226+
return 32;
227+
}
228+
219229
protected:
220230
llvm::SmallVector<mlir::Value> EmitReduction(
221231
int group_id, EmitterState& state) const override;

xla/backends/gpu/codegen/emitters/transpose.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,10 @@ using mlir::ValueRange;
7575
using mlir::func::FuncOp;
7676
using mlir::func::ReturnOp;
7777

78-
constexpr int kNumRows = 4;
79-
constexpr int kNumThreadsPerBlock = 128;
80-
constexpr int kMaxVectorizedBytes = 4;
78+
constexpr int kTileSize = 32;
79+
constexpr int kNumRows = 8;
80+
constexpr int kNumThreadsPerBlock = kNumRows * kTileSize;
81+
constexpr int kMaxVectorizedBytes = 16;
8182

8283
} // namespace
8384

@@ -87,7 +88,7 @@ TransposeFusion::TransposeFusion(const HloFusionAnalysis& analysis)
8788
permutation_(transpose_.permutation),
8889
input_shape_(
8990
Permute(transpose_.dimensions, InversePermutation(permutation_))),
90-
base_block_size_(WarpSize(analysis_.device_info())) {
91+
base_block_size_(kTileSize) {
9192
ConstHloInstructionSet transposes_to_tile;
9293
int index = 0;
9394
int64_t shmem_usage = 0;

xla/stream_executor/rocm/rocm_executor.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,7 @@ absl::StatusOr<int64_t> GetMaxRegistersPerBlock(hipDevice_t device) {
286286
}
287287

288288
absl::StatusOr<int64_t> GetThreadsPerWarp(hipDevice_t device) {
289-
// TODO(ROCm): This is almost certainly wrong but tests seem to rely on it.
290-
return 32;
289+
return GetSimpleAttribute<int64_t>(device, hipDeviceAttributeWarpSize);
291290
}
292291

293292
absl::Status GetGridLimits(int* x, int* y, int* z, hipDevice_t device) {

0 commit comments

Comments
 (0)