Commit 3febc29

Authored by zahiqbal, cj401-amd, zoranjovanovic-ns, and alekstheod
Triton fixes porting from v0.6.0 (#389)
* rocprof-sdk addition, upstream PR: openxla/pull/29769. Squashes the following commits:
  - Update rocprofiler-sdk (v3) along with roctracer (v1) for rocm-jaxlib-v0.6.0 (#302): integrate rocprofiler-sdk, with roctracer as a backup selected via bazel_options from the CLI (cherry picked from commit 7775dd0)
  - Use VLOG(2) to replace LOG(INFO), so PGLE has no verbose info (#357) (cherry picked from commit 5950125)
  - Update with kernel details for rocm-7.x (#364) (cherry picked from commit 5597c0d)
  - Remove the previously hard-coded rocprofiler-sdk path and add skip_rocprofiler_sdk to avoid loading `rocprofiler-sdk` (#369) (cherry picked from commit ff74b5f)
  - Fixed buffer comparator test
  - Misc fixes ported from rocm-jaxlib-v0.6.0
  Co-authored-by: Pavel Emeliyanenko <[email protected]> (cherry picked from commits f013645 and b03cd94)
* Added support for waves_per_eu function attribute (#181) (cherry picked from commits bc1d816 and d3f94e9)
* Removed a two-line change (revert of half of the openxla#25959 commit) (cherry picked from commit 109e138)
* Fixes for jax 0.6.0 (#207): fix the jax plugin 0.6.0 by dropping NEEDED linking to unnecessary libs (these are loaded by amdhipruntime, not by us) and adding the missing NEEDED entry on the MIOpen shared object; minor rocblas-related changes for ROCm 7.0. Co-authored-by: Zoran Jovanovic <[email protected]> (cherry picked from commits 0de7d49 and 28f10a0)
* Add hipBLASLt support for gfx11 (#301) (cherry picked from commit f814bff)
* Add bf16 starting from gfx11; bugfix & optimize RocmComputeCapability (#303): fix and improve device_description.h::RocmComputeCapability and enable ALG_DOT_BF16* on ROCm with HW support (cherry picked from commit 510ea06)
* [ROCm] Use bundled bitcode files (#196); also trim the bitcode file list to ockl.bc and ocml.bc only (cherry picked from commit fc9e3c3)
* Add MIOPEN_FIND_ENFORCE for ROCm 7 convolution gemms (#312); exclude failing CollectiveOpsE2E tests (cherry picked from commit fb6ddfb)
* Restore the RocmComputeCapability::gfx11_rx7900() and gfx12_rx8900() methods (#333); at least gfx11_rx7900() is still needed for the TF build (cherry picked from commit 13c3de1)
* Make device_count_ atomic (#343), using relaxed memory order; fix a build error (cherry picked from commit 8513f2d)
* Fix hardcoded max registers (#345) (cherry picked from commit f3e170a)
* Fix hardcoded ECC enabled (#348) (cherry picked from commit 9cfa74a)
* Remove reserved memory (#349) (cherry picked from commit 0015d0e)
* Add rocm_dev config for remote caching (#353) (cherry picked from commit c815420)
* Added ROCm 7 support to EnablePeerAccess (#347): use the wrap namespace, clang-format, and add comments (cherry picked from commit 85548a7)
* [ROCm] Disable Cudnn fusions (#358) (cherry picked from commit edab8b2)
* Ported all Triton-related changes from v0.6.0 to v0.7.1 (cherry picked from commit 1851bcc)
* Disable softmax Triton fusion if Triton GEMM is off (#281): disable the softmax rewriter when Triton GEMM is disabled and add a specific flag to enable Triton softmax fusion (cherry picked from commit 51a7f4b)
* [ROCm][Triton] Disable transposed load in certain conditions (cherry picked from commit 50860e9)
* Enable unit tests that pass after fixing some Triton-related issues (#285); fusion_emitter_device_legacy_test still fails on MI200 (cherry picked from commit 97dd565)
* Rocm jaxlib v0.6.0 triton support ut (#279): fixed triton/support_test (no fmfa), fixed a rounding-mode issue in AccelerateAMDMatmul, and fixed issues with the usage of mfma in support_test (cherry picked from commit 44f7d87)
* Restore gpu_triton_custom_call_test (#262) (cherry picked from commit 32eafa4)
* Skipped the CanNotEmitTritonCustomCallOnPreAmpereGpu test for ROCm (cherry picked from commits 56ec7ec and b1f3e9f)
* Fixed createTritonAMDGPULowerInstructionSchedHintsPass (#179) (cherry picked from commits 8517a3a and c62e47d)
* Fixed a bazel build issue

Co-authored-by: Chunyu Jin <[email protected]>
Co-authored-by: zoranjovanovic-ns <[email protected]>
Co-authored-by: Alex <[email protected]>
1 parent 30c0943 commit 3febc29
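Of the ported changes, the device_count_ fix (#343) follows the standard pattern for a counter that must be updated safely from several threads but is never used to order other memory accesses. A minimal sketch of that pattern, with illustrative names rather than the actual XLA code:

#include <atomic>

class DeviceRegistry {  // hypothetical stand-in for the real class
 public:
  // Atomic increment; relaxed ordering suffices because the counter does
  // not synchronize any other data between threads.
  void AddDevice() { device_count_.fetch_add(1, std::memory_order_relaxed); }

  int device_count() const {
    return device_count_.load(std::memory_order_relaxed);
  }

 private:
  std::atomic<int> device_count_{0};
};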

File tree

13 files changed: +265 −29 lines


build_tools/rocm/run_xla.sh

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ GPU_NAME=(`rocminfo | grep -m 1 gfx`)
 GPU_NAME=${GPU_NAME[1]}
 
 EXCLUDED_TESTS=(
+    # //xla/service/gpu/tests:gpu_kernel_tiling_test_gpu_amd_any
+    GpuKernelTilingTest.ColumnReductionWithLayoutChangeTiled
+    GpuKernelTilingTest.ReductionInputTooLarge
     # //xla/pjrt/c:pjrt_c_api_gpu_test_gpu_amd_any
     PjrtCAPIGpuExtensionTest.TritonCompile
     # //xla/backends/gpu/codegen/triton:fusion_emitter_device_test_gpu_amd_any
third_party/triton/temporary/0001-AMD-Quick-fix-disabling-transposed-load-used-as-diff.patch

Lines changed: 46 additions & 0 deletions

From d539916e4d49cca93f54a5f99f7822050205432c Mon Sep 17 00:00:00 2001
From: Jungwook Park <[email protected]>
Date: Thu, 7 Aug 2025 06:34:49 -0500
Subject: [PATCH] [AMD] Quick fix disabling transposed load used as different
 type.

Disabling transposedLoad if dot is using it as a different element type.
Otherwise it's picking the wrong vector size when lowering.
---
 .../lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
index 661a17678..6bda3a818 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
@@ -214,6 +214,26 @@ private:
       return false;
     }
 
+    // The transposed load can be used only when it is consumed by a dot op
+    // with the loaded data type.
+    int opIdx = 0;
+    triton::gpu::LocalLoadOp lLoad = cast<triton::gpu::LocalLoadOp>(localLoad);
+    if (auto dotEnc = lLoad.getSrc().getType().getEncoding())
+      opIdx = cast<triton::gpu::DotOperandEncodingAttr>(dotEnc).getOpIdx();
+    else
+      return false;
+
+    SetVector<Operation *> slice;
+    getForwardSlice(localLoad, &slice);
+    for (auto op : slice) {
+      if (auto dotOp = dyn_cast<triton::DotOp>(op)) {
+        auto inputMat = (opIdx == 0) ? dotOp.getA() : dotOp.getB();
+        auto bitwidthMat = inputMat.getType().getElementTypeBitWidth();
+        if (bitwidth != bitwidthMat)
+          return false;
+      }
+    }
+
     return true;
   }
 
-- 
2.34.1
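The rule this guard enforces can be stated separately from the IR walk: every tt.dot reached from the local_load must consume its operand at the bitwidth the data was loaded with, otherwise the transposed-load path is rejected. A standalone sketch of that predicate (a hypothetical helper, not part of the patch):

#include <vector>

// True only if every dot operand bitwidth matches the loaded bitwidth;
// a mismatch would make the lowering pick the wrong vector size.
bool TransposedLoadBitwidthsMatch(unsigned loadBitwidth,
                                  const std::vector<unsigned>& dotBitwidths) {
  for (unsigned dotBitwidth : dotBitwidths) {
    if (dotBitwidth != loadBitwidth) return false;
  }
  return true;
}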
third_party/triton/temporary/accelerateamdmatmul.patch

Lines changed: 14 additions & 0 deletions

--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -1068,7 +1068,10 @@ public:
     if (isFloat(srcElTy) && isFloat(dstElTy)) {
       auto rmode =
           RoundingModeAttr::get(rewriter.getContext(), RoundingMode::RTNE);
-      return rewriter.create<FpToFpOp>(loc, dstTy, v, rmode);
+      if (dstElTy.getIntOrFloatBitWidth() < srcElTy.getIntOrFloatBitWidth()) {
+        return rewriter.create<FpToFpOp>(loc, dstTy, v, rmode);
+      }
+      return rewriter.create<FpToFpOp>(loc, dstTy, v);
     }
     if (!isFloat(srcElTy) && isFloat(dstElTy))
       return rewriter.create<arith::SIToFPOp>(loc, dstTy, v);
third_party/triton/temporary/accelerateamdmatmul2.patch

Lines changed: 24 additions & 0 deletions

--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -380,9 +380,17 @@ Value convertAndCastTensor(PatternRewriter &rewriter, Value value,
   else if (oldElemType.isF32() && newElemType.isF16())
     castedTensor =
         rewriter.create<arith::TruncFOp>(loc, castedType, convertedTensor);
-  else
-    castedTensor =
-        rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor);
+  else {
+    if (oldElemType.getIntOrFloatBitWidth() > newElemType.getIntOrFloatBitWidth()) {
+      auto rmode =
+          RoundingModeAttr::get(rewriter.getContext(), RoundingMode::RTNE);
+      castedTensor =
+          rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor, rmode);
+    } else {
+      castedTensor =
+          rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor);
+    }
+  }
   }
   return castedTensor;
 }
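Both AccelerateAMDMatmul patches encode the same rule: attach an explicit RTNE rounding mode to FpToFpOp only when the float-to-float conversion narrows the element type; widening casts are exact and get no rounding mode. Distilled as a hypothetical predicate:

// Only narrowing casts (e.g. f32 -> bf16) can lose precision and need a
// rounding mode; widening casts (e.g. bf16 -> f32) are exact.
bool NeedsRoundingMode(unsigned srcBitWidth, unsigned dstBitWidth) {
  return dstBitWidth < srcBitWidth;
}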

third_party/triton/temporary/series.bzl

Lines changed: 3 additions & 0 deletions
@@ -19,5 +19,8 @@ temporary_patch_list = [
     "//third_party/triton:temporary/tutorial_fixes.patch",
     "//third_party/triton:temporary/ws_fix.patch",
     "//third_party/triton:temporary/ws_ub_fix.patch",
+    "//third_party/triton:temporary/0001-AMD-Quick-fix-disabling-transposed-load-used-as-diff.patch",
+    "//third_party/triton:temporary/accelerateamdmatmul.patch",
+    "//third_party/triton:temporary/accelerateamdmatmul2.patch",
     # Add new patches just above this line
 ]

xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc

Lines changed: 2 additions & 4 deletions
@@ -146,10 +146,8 @@ absl::Status CreateTritonPipeline(
   pm->addPass(mlir::createCanonicalizerPass());
   pm->addPass(mlir::createCSEPass());
   pm->addPass(mlir::createSymbolDCEPass());
-  if (/*(instruction_sched_variant=="none") == */ false) {
-    pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
-        cc.gfx_version(), num_stages));
-  }
+  pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
+      cc.gfx_version(), num_stages));
   pm->addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true));
   // There is no clusters in ROCm for now.
   out_cluster_info.clusterDimX = 1;

xla/backends/gpu/codegen/triton/emitter_helpers.cc

Lines changed: 9 additions & 0 deletions
@@ -106,6 +106,12 @@ absl::StatusOr<Type> TritonType(EmitterLocOpBuilder& b, PrimitiveType t) {
       return b.getType<mlir::Float8E5M2Type>();
     case F8E4M3FN:
       return b.getType<mlir::Float8E4M3FNType>();
+    case F8E4M3B11FNUZ:
+      return b.getType<mlir::Float8E4M3B11FNUZType>();
+    case F8E5M2FNUZ:
+      return b.getType<mlir::Float8E5M2FNUZType>();
+    case F8E4M3FNUZ:
+      return b.getType<mlir::Float8E4M3FNUZType>();
     default:
       return absl::UnimplementedError(
           absl::StrCat("This type is not supported yet: ",
@@ -126,6 +132,9 @@ absl::StatusOr<PrimitiveType> GetPrimitiveType(Type t) {
   if (t.isInteger(1)) return PRED;
   if (mlir::isa<mlir::Float8E5M2Type>(t)) return F8E5M2;
   if (mlir::isa<mlir::Float8E4M3FNType>(t)) return F8E4M3FN;
+  if (mlir::isa<mlir::Float8E4M3B11FNUZType>(t)) return F8E4M3B11FNUZ;
+  if (mlir::isa<mlir::Float8E5M2FNUZType>(t)) return F8E5M2FNUZ;
+  if (mlir::isa<mlir::Float8E4M3FNUZType>(t)) return F8E4M3FNUZ;
   return absl::UnimplementedError("Unsupported type in getPrimitiveType.\n");
 }
 
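TritonType and GetPrimitiveType are intended to be mutual inverses, which is why each FNUZ type is added to both functions. A test-style sketch of that invariant (an assumed check, not part of the commit), using XLA's TF_ASSIGN_OR_RETURN and CHECK_EQ:

for (PrimitiveType t : {F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
  TF_ASSIGN_OR_RETURN(mlir::Type mlir_type, TritonType(b, t));
  TF_ASSIGN_OR_RETURN(PrimitiveType round_tripped, GetPrimitiveType(mlir_type));
  CHECK_EQ(round_tripped, t);  // each mapping must invert the other
}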

xla/backends/gpu/codegen/triton/support.cc

Lines changed: 47 additions & 10 deletions
@@ -62,8 +62,12 @@ bool IsTritonSupportedDataType(PrimitiveType type,
     case F64:
       return true;
     case F8E5M2:
-    case F8E4M3FN:
-      return std::holds_alternative<se::CudaComputeCapability>(gpu_version);
+    case F8E4M3FN:
+      return std::holds_alternative<se::CudaComputeCapability>(gpu_version) ||
+             std::holds_alternative<se::RocmComputeCapability>(gpu_version);
+    case F8E5M2FNUZ:
+    case F8E4M3FNUZ:
+      return std::holds_alternative<se::RocmComputeCapability>(gpu_version);
     case BF16:
       return std::holds_alternative<se::CudaComputeCapability>(gpu_version) ||
              (std::holds_alternative<se::RocmComputeCapability>(gpu_version) &&
@@ -92,7 +96,10 @@ absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
   absl::flat_hash_set<HloOpcode> ret{HloOpcode::kAbs, HloOpcode::kCopy};
 
   if (element_type != PrimitiveType::F8E5M2 &&
-      element_type != PrimitiveType::F8E4M3FN) {
+      element_type != PrimitiveType::F8E4M3FN &&
+      element_type != PrimitiveType::F8E4M3B11FNUZ &&
+      element_type != PrimitiveType::F8E5M2FNUZ &&
+      element_type != PrimitiveType::F8E4M3FNUZ) {
     ret.insert(HloOpcode::kNegate);
   }
 
@@ -171,7 +178,10 @@ absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
     PrimitiveType element_type, const se::GpuComputeCapability& gpu_version) {
   if (element_type == PrimitiveType::S4 || element_type == PrimitiveType::U16 ||
       element_type == PrimitiveType::F8E5M2 ||
-      element_type == PrimitiveType::F8E4M3FN) {
+      element_type == PrimitiveType::F8E4M3FN ||
+      element_type == PrimitiveType::F8E4M3B11FNUZ ||
+      element_type == PrimitiveType::F8E5M2FNUZ ||
+      element_type == PrimitiveType::F8E4M3FNUZ) {
     return {};
   }
 
@@ -220,7 +230,10 @@ absl::flat_hash_set<HloOpcode> TritonSupportedTernaryElementwiseOps(
   }
 
   if (element_type == PrimitiveType::F8E5M2 ||
-      element_type == PrimitiveType::F8E4M3FN) {
+      element_type == PrimitiveType::F8E4M3FN ||
+      element_type == PrimitiveType::F8E4M3B11FNUZ ||
+      element_type == PrimitiveType::F8E5M2FNUZ ||
+      element_type == PrimitiveType::F8E4M3FNUZ) {
     return {HloOpcode::kSelect};
   }
 
@@ -248,7 +261,10 @@ CodegenDecision CanTritonHandleReduce(
     const HloReduceInstruction& reduce,
     const se::GpuComputeCapability& gpu_version) {
   if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN ||
-      reduce.shape().element_type() == PrimitiveType::F8E5M2) {
+      reduce.shape().element_type() == PrimitiveType::F8E5M2 ||
+      reduce.shape().element_type() == PrimitiveType::F8E5M2FNUZ ||
+      reduce.shape().element_type() == PrimitiveType::F8E4M3FNUZ ||
+      reduce.shape().element_type() == PrimitiveType::F8E4M3B11FNUZ) {
     return CodegenDecision::Forbid(
         "F8E4M3FN and F8E5M2 are not supported for reductions.");
   }
@@ -296,7 +312,8 @@ absl::Status CheckSupportedCheckDotDimensions(const HloDotInstruction& dot) {
   return absl::OkStatus();
 }
 
-bool IsSupportedDotAlgorithm(PrecisionConfig::Algorithm algorithm) {
+bool IsSupportedDotAlgorithm(PrecisionConfig::Algorithm algorithm,
+                             const se::GpuComputeCapability& gpu_version) {
   switch (algorithm) {
     case PrecisionConfig::ALG_UNSET:
     case PrecisionConfig::ALG_DOT_F16_F16_F16:
@@ -309,8 +326,13 @@ bool IsSupportedDotAlgorithm(PrecisionConfig::Algorithm algorithm) {
     case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
     case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9:
-      return true;
+      if (!std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+        return true;
+      }
     case PrecisionConfig::ALG_DOT_BF16_BF16_BF16:
+      if (std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+        return true;
+      }
     case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
     case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
     default:
@@ -336,7 +358,15 @@ CodegenDecision AreTypesSupportedByAlgUnsetDot(
   }
 }
 
-  auto supported_float_types = {BF16, F16, F32, F64, F8E5M2, F8E4M3FN};
+  if (input_type == F8E4M3B11FNUZ || result_type == F8E4M3B11FNUZ ||
+      input_type == F64) {
+    if (std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+      return CodegenDecision::Forbid(
+          "Dot operation for F8E4M3B11FNUZ is not supported on ROCM.");
+    }
+  }
+
+  auto supported_float_types = {BF16, F16, F32, F64, F8E5M2};
   if (absl::c_linear_search(supported_float_types, input_type)) {
     return CodegenDecision::Allow();
   }
@@ -405,6 +435,11 @@ CodegenDecision AreDotAlgorithmInputAndOutputConversionsSupported(
     return forbid("Unsupported BF16 on GPUs before Blackwell");
   }
 
+  if (allowed_operands_types_or->front() == PrimitiveType::F64 &&
+      std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+    return forbid("Unsupported result conversion");
+  }
+
   if (allowed_operands_types_or->size() != 1) {
     if (lhs_type == rhs_type &&
         absl::c_linear_search(*allowed_operands_types_or, lhs_type)) {
@@ -467,7 +502,7 @@ CodegenDecision IsTritonSupportedDot(
   const PrecisionConfig& precision_config = dot.precision_config();
   const PrecisionConfig::Algorithm algorithm = precision_config.algorithm();
 
-  if (!IsSupportedDotAlgorithm(algorithm)) {
+  if (!IsSupportedDotAlgorithm(algorithm, gpu_version)) {
     return CodegenDecision::Forbid(
         absl::StrCat("Unsupported dot algorithm: ",
                      PrecisionConfig::Algorithm_Name(algorithm)));
@@ -625,6 +660,8 @@ CodegenDecision IsTritonSupportedInstructionImpl(
     return CodegenDecision(
         element_type != PrimitiveType::F8E4M3FN &&
         element_type != PrimitiveType::F8E5M2 &&
+        element_type != PrimitiveType::F8E4M3FNUZ &&
+        element_type != PrimitiveType::F8E5M2FNUZ &&
         element_type != PrimitiveType::S4,
         "F8E4M3FN, F8E5M2 and S4 are not supported for iota.");
   }
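All of the ROCm gating above hinges on se::GpuComputeCapability being a std::variant over the CUDA and ROCm capability types, so a backend check is just std::holds_alternative. A self-contained sketch of the mechanism (struct bodies simplified):

#include <variant>

struct CudaComputeCapability { int major = 0, minor = 0; };
struct RocmComputeCapability { /* gfx version, feature queries, ... */ };
using GpuComputeCapability =
    std::variant<CudaComputeCapability, RocmComputeCapability>;

// FNUZ FP8 formats are AMD-specific, so support is decided purely by
// which alternative the variant holds.
bool IsFnuzFp8Supported(const GpuComputeCapability& gpu_version) {
  return std::holds_alternative<RocmComputeCapability>(gpu_version);
}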
