
Commit dace1ec

alekstheod authored and zahiqbal committed
Ported all triton related changes from v0.6.0 to v0.7.1
(cherry picked from commit 1851bcc)

Disable softmax triton fusion if triton gemm is off (#281)
* Disable softmax rewriter triton if triton gemm is disabled
* Add specific flag to enable triton softmax fusion
* Address review comments
(cherry picked from commit 51a7f4b)

[ROCm][Triton] Disable transposed load in certain conditions
(cherry picked from commit 50860e9)

Enable unit tests that pass after fixing some Triton related issues. (#285)
* Enable unit tests that pass after fixing some Triton related issues.
* fusion_emitter_device_legacy_test still fails on MI200
(cherry picked from commit 97dd565)

Rocm jaxlib v0.6.0 triton support ut (#279)
* Fixed triton/support_test - no fmfa.
* Fix issue with rounding mode in accelerate amd matmul.
* Fixed issues with usage of mfma in support_test.
(cherry picked from commit 44f7d87)

Restore gpu_triton_custom_call_test (#262)
(cherry picked from commit 32eafa4)

Skipped CanNotEmitTritonCustomCallOnPreAmpereGpu test for ROCM.
(cherry picked from commit 56ec7ec)
(cherry picked from commit b1f3e9f)

fixed createTritonAMDGPULowerInstructionSchedHintsPass (#179)
(cherry picked from commit 8517a3a)
(cherry picked from commit c62e47d)

fixed bazel build issue
1 parent c311726 commit dace1ec

File tree

13 files changed: +265, -29 lines

build_tools/rocm/run_xla.sh

Lines changed: 3 additions & 0 deletions

@@ -45,6 +45,9 @@ GPU_NAME=(`rocminfo | grep -m 1 gfx`)
 GPU_NAME=${GPU_NAME[1]}
 
 EXCLUDED_TESTS=(
+    # //xla/service/gpu/tests:gpu_kernel_tiling_test_gpu_amd_any
+    GpuKernelTilingTest.ColumnReductionWithLayoutChangeTiled
+    GpuKernelTilingTest.ReductionInputTooLarge
     # //xla/pjrt/c:pjrt_c_api_gpu_test_gpu_amd_any
     PjrtCAPIGpuExtensionTest.TritonCompile
     # //xla/backends/gpu/codegen/triton:fusion_emitter_device_test_gpu_amd_any

third_party/triton/temporary/0001-AMD-Quick-fix-disabling-transposed-load-used-as-diff.patch

Lines changed: 46 additions & 0 deletions (new file)

From d539916e4d49cca93f54a5f99f7822050205432c Mon Sep 17 00:00:00 2001
From: Jungwook Park <[email protected]>
Date: Thu, 7 Aug 2025 06:34:49 -0500
Subject: [PATCH] [AMD] Quick fix disabling transposed load used as different
 type.

Disabling transposedLoad if dot is using it as a different element type.
Otherwise it's picking wrong vectorsize when lowering.
---
 .../lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
index 661a17678..6bda3a818 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
@@ -214,6 +214,26 @@ private:
     return false;
   }
 
+  // transposed load can be used only when it's consumed by dot with the
+  // loaded data type.
+  int opIdx = 0;
+  triton::gpu::LocalLoadOp lLoad = cast<triton::gpu::LocalLoadOp>(localLoad);
+  if (auto dotEnc = lLoad.getSrc().getType().getEncoding())
+    opIdx = cast<triton::gpu::DotOperandEncodingAttr>(dotEnc).getOpIdx();
+  else
+    return false;
+
+  SetVector<Operation *> slice;
+  getForwardSlice(localLoad, &slice);
+  for (auto op : slice) {
+    if (auto dotOp = dyn_cast<triton::DotOp>(op)) {
+      auto inputMat = (opIdx == 0) ? dotOp.getA() : dotOp.getB();
+      auto bitwidthMat = inputMat.getType().getElementTypeBitWidth();
+      if (bitwidth != bitwidthMat)
+        return false;
+    }
+  }
+
   return true;
 }
-- 
2.34.1
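
The guard this patch adds is easiest to read in isolation: walk the forward slice of the local_load and reject the transposed-load path whenever a consuming tt.dot sees the operand at a different element bitwidth. A minimal sketch of the same check, assuming localLoad is the mlir::Operation* being lowered and bitwidth is the element width the load materializes (the free-standing helper and its name are illustrative, not the Triton code itself):

// Minimal sketch of the patch's guard, restated outside its original class.
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace ttg = mlir::triton::gpu;

static bool dotConsumersMatchBitwidth(mlir::Operation *localLoad,
                                      unsigned bitwidth) {
  // The load must feed a dot operand; otherwise the fast path is rejected.
  auto lLoad = mlir::cast<ttg::LocalLoadOp>(localLoad);
  auto dotEnc = lLoad.getSrc().getType().getEncoding();
  if (!dotEnc) return false;
  int opIdx = mlir::cast<ttg::DotOperandEncodingAttr>(dotEnc).getOpIdx();

  // Every tt.dot reachable from the load must consume operand opIdx at the
  // bitwidth the load produces; a mismatch means the lowering would pick
  // the wrong vector size.
  mlir::SetVector<mlir::Operation *> slice;
  mlir::getForwardSlice(localLoad, &slice);
  for (mlir::Operation *op : slice) {
    if (auto dot = mlir::dyn_cast<mlir::triton::DotOp>(op)) {
      auto operand = (opIdx == 0) ? dot.getA() : dot.getB();
      if (operand.getType().getElementTypeBitWidth() != bitwidth)
        return false;
    }
  }
  return true;
}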

third_party/triton/temporary/accelerateamdmatmul.patch

Lines changed: 14 additions & 0 deletions (new file)

--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -1068,7 +1068,10 @@ public:
     if (isFloat(srcElTy) && isFloat(dstElTy)) {
       auto rmode =
           RoundingModeAttr::get(rewriter.getContext(), RoundingMode::RTNE);
-      return rewriter.create<FpToFpOp>(loc, dstTy, v, rmode);
+      if (dstElTy.getIntOrFloatBitWidth() < srcElTy.getIntOrFloatBitWidth()) {
+        return rewriter.create<FpToFpOp>(loc, dstTy, v, rmode);
+      }
+      return rewriter.create<FpToFpOp>(loc, dstTy, v);
     }
     if (!isFloat(srcElTy) && isFloat(dstElTy))
       return rewriter.create<arith::SIToFPOp>(loc, dstTy, v);

third_party/triton/temporary/accelerateamdmatmul2.patch

Lines changed: 24 additions & 0 deletions (new file)

--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -380,9 +380,17 @@ Value convertAndCastTensor(PatternRewriter &rewriter, Value value,
   else if (oldElemType.isF32() && newElemType.isF16())
     castedTensor =
         rewriter.create<arith::TruncFOp>(loc, castedType, convertedTensor);
-  else
-    castedTensor =
-        rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor);
+  else {
+    if(oldElemType.getIntOrFloatBitWidth() > newElemType.getIntOrFloatBitWidth()) {
+      auto rmode =
+          RoundingModeAttr::get(rewriter.getContext(), RoundingMode::RTNE);
+      castedTensor =
+          rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor, rmode);
+    } else {
+      castedTensor =
+          rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor);
+    }
+  }
 }
 return castedTensor;
 }
 
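
Both AccelerateAMDMatmul patches encode the same rule: attach an explicit RTNE rounding mode to tt.FpToFpOp only when the float cast narrows; a widening cast is exact, and the patches suggest that passing a rounding mode there was the source of the bug. A combined sketch of the shared logic (free function with illustrative names, not the patched code itself):

// Sketch: rounding mode only for narrowing float-to-float casts.
mlir::Value castFloatTensor(mlir::PatternRewriter &rewriter,
                            mlir::Location loc, mlir::RankedTensorType dstTy,
                            mlir::Value v) {
  auto srcElTy =
      mlir::cast<mlir::RankedTensorType>(v.getType()).getElementType();
  auto dstElTy = dstTy.getElementType();
  if (dstElTy.getIntOrFloatBitWidth() < srcElTy.getIntOrFloatBitWidth()) {
    // Narrowing loses precision: the op needs an explicit rounding mode.
    auto rmode = mlir::triton::RoundingModeAttr::get(
        rewriter.getContext(), mlir::triton::RoundingMode::RTNE);
    return rewriter.create<mlir::triton::FpToFpOp>(loc, dstTy, v, rmode);
  }
  // Widening is exact: no rounding mode.
  return rewriter.create<mlir::triton::FpToFpOp>(loc, dstTy, v);
}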

third_party/triton/temporary/series.bzl

Lines changed: 3 additions & 0 deletions

@@ -19,5 +19,8 @@ temporary_patch_list = [
     "//third_party/triton:temporary/tutorial_fixes.patch",
     "//third_party/triton:temporary/ws_fix.patch",
     "//third_party/triton:temporary/ws_ub_fix.patch",
+    "//third_party/triton:temporary/0001-AMD-Quick-fix-disabling-transposed-load-used-as-diff.patch",
+    "//third_party/triton:temporary/accelerateamdmatmul.patch",
+    "//third_party/triton:temporary/accelerateamdmatmul2.patch",
     # Add new patches just above this line
 ]

xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc

Lines changed: 2 additions & 4 deletions

@@ -146,10 +146,8 @@ absl::Status CreateTritonPipeline(
   pm->addPass(mlir::createCanonicalizerPass());
   pm->addPass(mlir::createCSEPass());
   pm->addPass(mlir::createSymbolDCEPass());
-  if (/*(instruction_sched_variant=="none") == */ false) {
-    pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
-        cc.gfx_version(), num_stages));
-  }
+  pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
+      cc.gfx_version(), num_stages));
   pm->addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true));
   // There is no clusters in ROCm for now.
   out_cluster_info.clusterDimX = 1;

xla/backends/gpu/codegen/triton/emitter_helpers.cc

Lines changed: 9 additions & 0 deletions

@@ -106,6 +106,12 @@ absl::StatusOr<Type> TritonType(EmitterLocOpBuilder& b, PrimitiveType t) {
       return b.getType<mlir::Float8E5M2Type>();
     case F8E4M3FN:
       return b.getType<mlir::Float8E4M3FNType>();
+    case F8E4M3B11FNUZ:
+      return b.getType<mlir::Float8E4M3B11FNUZType>();
+    case F8E5M2FNUZ:
+      return b.getType<mlir::Float8E5M2FNUZType>();
+    case F8E4M3FNUZ:
+      return b.getType<mlir::Float8E4M3FNUZType>();
     default:
       return absl::UnimplementedError(
           absl::StrCat("This type is not supported yet: ",

@@ -126,6 +132,9 @@ absl::StatusOr<PrimitiveType> GetPrimitiveType(Type t) {
   if (t.isInteger(1)) return PRED;
   if (mlir::isa<mlir::Float8E5M2Type>(t)) return F8E5M2;
   if (mlir::isa<mlir::Float8E4M3FNType>(t)) return F8E4M3FN;
+  if (mlir::isa<mlir::Float8E4M3B11FNUZType>(t)) return F8E4M3B11FNUZ;
+  if (mlir::isa<mlir::Float8E5M2FNUZType>(t)) return F8E5M2FNUZ;
+  if (mlir::isa<mlir::Float8E4M3FNUZType>(t)) return F8E4M3FNUZ;
   return absl::UnimplementedError("Unsupported type in getPrimitiveType.\n");
 }
 
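
The three added cases are the FNUZ fp8 variants (no infinities, no negative zero) used on AMD hardware. Since TritonType and GetPrimitiveType are inverses for these mappings, a round-trip check is the natural test. The fragment below is a hypothetical sketch, assuming an EmitterLocOpBuilder b and the usual XLA/TSL test macros; it is not a test from this commit:

// Hypothetical round-trip check for the newly mapped fp8 types.
for (PrimitiveType t : {F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
  TF_ASSERT_OK_AND_ASSIGN(mlir::Type mlir_type, TritonType(b, t));
  TF_ASSERT_OK_AND_ASSIGN(PrimitiveType round_tripped,
                          GetPrimitiveType(mlir_type));
  EXPECT_EQ(round_tripped, t);
}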

xla/backends/gpu/codegen/triton/support.cc

Lines changed: 47 additions & 10 deletions

@@ -62,8 +62,12 @@ bool IsTritonSupportedDataType(PrimitiveType type,
     case F64:
      return true;
     case F8E5M2:
-    case F8E4M3FN:
-      return std::holds_alternative<se::CudaComputeCapability>(gpu_version);
+    case F8E4M3FN:
+      return std::holds_alternative<se::CudaComputeCapability>(gpu_version) ||
+             std::holds_alternative<se::RocmComputeCapability>(gpu_version);
+    case F8E5M2FNUZ:
+    case F8E4M3FNUZ:
+      return std::holds_alternative<se::RocmComputeCapability>(gpu_version);
     case BF16:
       return std::holds_alternative<se::CudaComputeCapability>(gpu_version) ||
              (std::holds_alternative<se::RocmComputeCapability>(gpu_version) &&

@@ -92,7 +96,10 @@ absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
   absl::flat_hash_set<HloOpcode> ret{HloOpcode::kAbs, HloOpcode::kCopy};
 
   if (element_type != PrimitiveType::F8E5M2 &&
-      element_type != PrimitiveType::F8E4M3FN) {
+      element_type != PrimitiveType::F8E4M3FN &&
+      element_type != PrimitiveType::F8E4M3B11FNUZ &&
+      element_type != PrimitiveType::F8E5M2FNUZ &&
+      element_type != PrimitiveType::F8E4M3FNUZ) {
     ret.insert(HloOpcode::kNegate);
   }
 
@@ -171,7 +178,10 @@ absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
     PrimitiveType element_type, const se::GpuComputeCapability& gpu_version) {
   if (element_type == PrimitiveType::S4 || element_type == PrimitiveType::U16 ||
       element_type == PrimitiveType::F8E5M2 ||
-      element_type == PrimitiveType::F8E4M3FN) {
+      element_type == PrimitiveType::F8E4M3FN ||
+      element_type == PrimitiveType::F8E4M3B11FNUZ ||
+      element_type == PrimitiveType::F8E5M2FNUZ ||
+      element_type == PrimitiveType::F8E4M3FNUZ) {
     return {};
   }
 
@@ -220,7 +230,10 @@ absl::flat_hash_set<HloOpcode> TritonSupportedTernaryElementwiseOps(
   }
 
   if (element_type == PrimitiveType::F8E5M2 ||
-      element_type == PrimitiveType::F8E4M3FN) {
+      element_type == PrimitiveType::F8E4M3FN ||
+      element_type == PrimitiveType::F8E4M3B11FNUZ ||
+      element_type == PrimitiveType::F8E5M2FNUZ ||
+      element_type == PrimitiveType::F8E4M3FNUZ) {
     return {HloOpcode::kSelect};
   }
 
@@ -248,7 +261,10 @@ CodegenDecision CanTritonHandleReduce(
     const HloReduceInstruction& reduce,
     const se::GpuComputeCapability& gpu_version) {
   if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN ||
-      reduce.shape().element_type() == PrimitiveType::F8E5M2) {
+      reduce.shape().element_type() == PrimitiveType::F8E5M2 ||
+      reduce.shape().element_type() == PrimitiveType::F8E5M2FNUZ ||
+      reduce.shape().element_type() == PrimitiveType::F8E4M3FNUZ ||
+      reduce.shape().element_type() == PrimitiveType::F8E4M3B11FNUZ) {
     return CodegenDecision::Forbid(
         "F8E4M3FN and F8E5M2 are not supported for reductions.");
   }

@@ -296,7 +312,8 @@ absl::Status CheckSupportedCheckDotDimensions(const HloDotInstruction& dot) {
   return absl::OkStatus();
 }
 
-bool IsSupportedDotAlgorithm(PrecisionConfig::Algorithm algorithm) {
+bool IsSupportedDotAlgorithm(PrecisionConfig::Algorithm algorithm,
+                             const se::GpuComputeCapability& gpu_version) {
   switch (algorithm) {
     case PrecisionConfig::ALG_UNSET:
     case PrecisionConfig::ALG_DOT_F16_F16_F16:

@@ -309,8 +326,13 @@ bool IsSupportedDotAlgorithm(PrecisionConfig::Algorithm algorithm) {
     case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
     case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9:
-      return true;
+      if (!std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+        return true;
+      }
     case PrecisionConfig::ALG_DOT_BF16_BF16_BF16:
+      if (std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+        return true;
+      }
     case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
     case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
     default:

@@ -336,7 +358,15 @@ CodegenDecision AreTypesSupportedByAlgUnsetDot(
     }
   }
 
-  auto supported_float_types = {BF16, F16, F32, F64, F8E5M2, F8E4M3FN};
+  if (input_type == F8E4M3B11FNUZ || result_type == F8E4M3B11FNUZ ||
+      input_type == F64) {
+    if (std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+      return CodegenDecision::Forbid(
+          "Dot operation for F8E4M3B11FNUZ is not supported on ROCM.");
+    }
+  }
+
+  auto supported_float_types = {BF16, F16, F32, F64, F8E5M2};
   if (absl::c_linear_search(supported_float_types, input_type)) {
     return CodegenDecision::Allow();
   }

@@ -405,6 +435,11 @@ CodegenDecision AreDotAlgorithmInputAndOutputConversionsSupported(
     return forbid("Unsupported BF16 on GPUs before Blackwell");
   }
 
+  if (allowed_operands_types_or->front() == PrimitiveType::F64 &&
+      std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+    return forbid("Unsupported result conversion");
+  }
+
   if (allowed_operands_types_or->size() != 1) {
     if (lhs_type == rhs_type &&
         absl::c_linear_search(*allowed_operands_types_or, lhs_type)) {

@@ -467,7 +502,7 @@ CodegenDecision IsTritonSupportedDot(
   const PrecisionConfig& precision_config = dot.precision_config();
   const PrecisionConfig::Algorithm algorithm = precision_config.algorithm();
 
-  if (!IsSupportedDotAlgorithm(algorithm)) {
+  if (!IsSupportedDotAlgorithm(algorithm, gpu_version)) {
     return CodegenDecision::Forbid(
         absl::StrCat("Unsupported dot algorithm: ",
                      PrecisionConfig::Algorithm_Name(algorithm)));

@@ -625,6 +660,8 @@ CodegenDecision IsTritonSupportedInstructionImpl(
     return CodegenDecision(
         element_type != PrimitiveType::F8E4M3FN &&
         element_type != PrimitiveType::F8E5M2 &&
+        element_type != PrimitiveType::F8E4M3FNUZ &&
+        element_type != PrimitiveType::F8E5M2FNUZ &&
         element_type != PrimitiveType::S4,
         "F8E4M3FN, F8E5M2 and S4 are not supported for iota.");
   }
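
The reworked switch in IsSupportedDotAlgorithm leans on case fallthrough and is easy to misread: on ROCm, the ALG_DOT_TF32_TF32_F32 group fails the first guard, falls through into the ALG_DOT_BF16_BF16_BF16 body, and returns true from its ROCm check, so that group stays supported on both platforms. The net behavioral change is only that ALG_DOT_BF16_BF16_BF16 is now accepted on ROCm. A fallthrough-free restatement, assuming the F8 cases and default return false (the hunk does not show them):

// Sketch: flat equivalent of the fallthrough logic in the hunk above.
bool IsSupportedDotAlgorithmFlat(PrecisionConfig::Algorithm algorithm,
                                 const se::GpuComputeCapability& gpu_version) {
  const bool is_rocm =
      std::holds_alternative<se::RocmComputeCapability>(gpu_version);
  switch (algorithm) {
    // ... unconditionally supported algorithms (ALG_UNSET, etc.) elided ...
    case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
    case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
    case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9:
      // CUDA returns true at the first guard; ROCm falls through into the
      // next case's ROCm guard and returns true there as well.
      return true;
    case PrecisionConfig::ALG_DOT_BF16_BF16_BF16:
      return is_rocm;  // the one behavioral change: now allowed on ROCm
    case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
    case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
    default:
      return false;
  }
}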
