Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion xla/backends/gpu/codegen/emitters/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class ReductionFusion : public EmitterBase {
return IndexingMap::GetUndefined();
}

// Returns the warp size for this emitter by forwarding to the free function
// ::xla::gpu::WarpSize with the device info held by analysis_.
// NOTE(review): the two signature lines below are the before/after pair of
// this diff hunk (file header: "11 additions & 1 deletion"). The `virtual`
// form is presumably the added line, since subclasses shown further down
// declare `WarpSize() const override`.
int64_t WarpSize() const {
virtual int64_t WarpSize() const {
return ::xla::gpu::WarpSize(analysis_.device_info());
}

Expand Down Expand Up @@ -198,6 +198,11 @@ class ColumnReductionFusion : public ReductionFusion {
public:
explicit ColumnReductionFusion(const HloFusionAnalysis& analysis);

// Overrides ReductionFusion::WarpSize() so that column reductions always see
// a warp size of 32, rather than the value the base class derives from
// analysis_.device_info().
// NOTE(review): the author marks this as a hack. Presumably this emitter
// still assumes 32-wide warps and would misbehave on devices reporting a
// different size — confirm, and replace with a device-derived value.
int64_t WarpSize() const override {
// PAE HACK HACK
return 32;
}

protected:
llvm::SmallVector<mlir::Value> EmitReduction(
int group_id, EmitterState& state) const override;
Expand All @@ -216,6 +221,11 @@ class SmallColumnReductionFusion : public ReductionFusion {
public:
explicit SmallColumnReductionFusion(const HloFusionAnalysis& analysis);

// Overrides ReductionFusion::WarpSize() so that small column reductions
// always see a warp size of 32, rather than the value the base class derives
// from analysis_.device_info().
// NOTE(review): the author marks this as a hack. Presumably this emitter
// still assumes 32-wide warps and would misbehave on devices reporting a
// different size — confirm, and replace with a device-derived value.
int64_t WarpSize() const override {
// PAE HACK HACK
return 32;
}

protected:
llvm::SmallVector<mlir::Value> EmitReduction(
int group_id, EmitterState& state) const override;
Expand Down
9 changes: 5 additions & 4 deletions xla/backends/gpu/codegen/emitters/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ using mlir::ValueRange;
using mlir::func::FuncOp;
using mlir::func::ReturnOp;

// File-local compile-time constants for the transpose emitter.
// NOTE(review): this span shows both sides of the diff hunk, so several names
// appear twice. The first three definitions are presumably the removed
// values and the last four the added ones — confirm against the merged file.
constexpr int kNumRows = 4;
constexpr int kNumThreadsPerBlock = 128;
constexpr int kMaxVectorizedBytes = 4;
// Fixed tile edge; the TransposeFusion constructor below uses it to
// initialize base_block_size_.
constexpr int kTileSize = 32;
constexpr int kNumRows = 8;
// Derived from the two constants above: 8 * 32 = 256 with the values shown.
constexpr int kNumThreadsPerBlock = kNumRows * kTileSize;
// presumably the upper bound, in bytes, on a single vectorized access —
// TODO confirm against its uses elsewhere in this file.
constexpr int kMaxVectorizedBytes = 16;

} // namespace

Expand All @@ -87,7 +88,7 @@ TransposeFusion::TransposeFusion(const HloFusionAnalysis& analysis)
permutation_(transpose_.permutation),
input_shape_(
Permute(transpose_.dimensions, InversePermutation(permutation_))),
base_block_size_(WarpSize(analysis_.device_info())) {
base_block_size_(kTileSize) {
Comment thread
i-chaochen marked this conversation as resolved.
ConstHloInstructionSet transposes_to_tile;
int index = 0;
int64_t shmem_usage = 0;
Expand Down
3 changes: 1 addition & 2 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,7 @@ absl::StatusOr<int64_t> GetMaxRegistersPerBlock(hipDevice_t device) {
}

// Returns the number of threads per warp for `device`.
// NOTE(review): both sides of the diff hunk are shown here (file header:
// "1 addition & 2 deletions"). The TODO comment and `return 32;` are
// presumably the removed lines, and the GetSimpleAttribute query of
// hipDeviceAttributeWarpSize is the replacement — confirm against the merged
// file.
absl::StatusOr<int64_t> GetThreadsPerWarp(hipDevice_t device) {
// TODO(ROCm): This is almost certainly wrong but tests seem to rely on it.
return 32;
return GetSimpleAttribute<int64_t>(device, hipDeviceAttributeWarpSize);
}

absl::Status GetGridLimits(int* x, int* y, int* z, hipDevice_t device) {
Expand Down