Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_tools/lint/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"multi_gpu_h100": (
"Used by `xla_test` to signal that multiple H100s are needed."
),
"skip_rocprofiler_sdk": "used to skip rocmtracer test as it calls rocprofiler-sdk via rocprofiler_force_configure",
}


Expand Down
6 changes: 5 additions & 1 deletion build_tools/rocm/run_xla.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ GPU_NAME=(`rocminfo | grep -m 1 gfx`)
GPU_NAME=${GPU_NAME[1]}

EXCLUDED_TESTS=(
# //xla/service/gpu/tests:gpu_kernel_tiling_test_gpu_amd_any
GpuKernelTilingTest.ColumnReductionWithLayoutChangeTiled
GpuKernelTilingTest.ReductionInputTooLarge
# //xla/pjrt/c:pjrt_c_api_gpu_test_gpu_amd_any
PjrtCAPIGpuExtensionTest.TritonCompile
# //xla/backends/gpu/codegen/triton:fusion_emitter_device_test_gpu_amd_any
Expand Down Expand Up @@ -92,6 +95,7 @@ BasicDotAlgorithmEmitterTestSuite/BasicDotAlgorithmEmitterTest.BasicAlgorithmIsE
CommandBufferTests/CommandBufferTest.IndexConditional/*
CommandBufferTests/CommandBufferTest.WhileLoop/*
CommandBufferTests/CommandBufferTest.TrueFalseConditional/*
BufferComparatorTest.VeryLargeArray_Device_U8_Aligned
)

BAZEL_DISK_CACHE_SIZE=100G
Expand Down Expand Up @@ -147,4 +151,4 @@ bazel --bazelrc=build_tools/rocm/rocm_xla.bazelrc test \
# clean up bazel disk_cache
bazel shutdown \
--disk_cache=${BAZEL_DISK_CACHE_DIR} \
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
2 changes: 1 addition & 1 deletion build_tools/rocm/run_xla_multi_gpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,4 @@ bazel --bazelrc=build_tools/rocm/rocm_xla.bazelrc test \
# clean up bazel disk_cache
bazel shutdown \
--disk_cache=${BAZEL_DISK_CACHE_DIR} \
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
13 changes: 13 additions & 0 deletions tensorflow.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,19 @@ build:rocm_clang_official --action_env=TF_ROCM_CLANG="1"
build:rocm_clang_official --linkopt="-fuse-ld=lld"
build:rocm_clang_official --host_linkopt="-fuse-ld=lld"

build:rocm_dev --remote_upload_local_results=false
build:rocm_dev --remote_cache="https://wardite.cluster.engflow.com"

build:rocm_rbe --bes_backend="grpcs://wardite.cluster.engflow.com"
build:rocm_rbe --bes_results_url="https://wardite.cluster.engflow.com/invocation/"
build:rocm_rbe --remote_executor="grpcs://wardite.cluster.engflow.com"
build:rocm_rbe --host_platform="//platform/linux_x64"
build:rocm_rbe --extra_execution_platforms="//platform/linux_x64"
build:rocm_rbe --platforms="//platform/linux_x64"
build:rocm_rbe --bes_timeout=600s
build:rocm_rbe --tls_client_certificate="/tf/certificates/ci-cert.crt"
build:rocm_rbe --tls_client_key="/tf/certificates/ci-cert.key"

build:rocm_ci --config=rocm_clang_official

build:rocm_ci_hermetic --config=rocm_clang_official
Expand Down
19 changes: 15 additions & 4 deletions third_party/gpus/rocm/BUILD.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,11 @@ cc_library(
],
strip_include_prefix = "%{rocm_root}",
deps = [
":amd_comgr",
":hsa_rocr",
":rocm_config",
":rocm_smi",
":rocprofiler_register",
":system_libs",
],
visibility = ["//visibility:public"],
)

# Used by jax_rocm_plugin to minimally link to hip runtime.
Expand Down Expand Up @@ -259,8 +257,8 @@ cc_library(
cc_library(
name = "miopen",
hdrs = glob(["%{rocm_root}/include/miopen/**"]),
srcs = glob(["%{rocm_root}/lib/libMIOpen*.so*"]),
data = glob([
"%{rocm_root}/lib/libMIOpen*.so*",
"%{rocm_root}/share/miopen/**",
]),
include_prefix = "rocm",
Expand Down Expand Up @@ -349,6 +347,19 @@ cc_library(
deps = [":rocm_config"],
)

cc_library(
name = "rocprofiler-sdk",
srcs = glob(["%{rocm_root}/lib/librocprofiler-sdk*.so*"]),
hdrs = glob(["%{rocm_root}/include/rocprofiler-sdk/**"]),
include_prefix = "rocm",
includes = [
"%{rocm_root}/include/",
],
strip_include_prefix = "%{rocm_root}",
visibility = ["//visibility:public"],
deps = [":rocm_config"],
)

cc_library(
name = "rocsolver",
srcs = glob(["%{rocm_root}/lib/librocsolver*.so*"]),
Expand Down
1 change: 1 addition & 0 deletions third_party/gpus/rocm_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def _find_libs(repository_ctx, rocm_config, miopen_path, rccl_path, bash_bin):
("rocsolver", rocm_config.rocm_toolkit_path),
("hipfft", rocm_config.rocm_toolkit_path),
("rocrand", rocm_config.rocm_toolkit_path),
("rocprofiler-sdk", rocm_config.rocm_toolkit_path),
]
]
if int(rocm_config.rocm_version_number) >= 40500:
Expand Down
1 change: 1 addition & 0 deletions third_party/rocm_device_libs/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# copybara:uncomment package(default_applicable_licenses = ["//third_party/tensorflow:license"])
96 changes: 96 additions & 0 deletions third_party/rocm_device_libs/build_defs.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
load("@bazel_skylib//lib:paths.bzl", "paths")

def bitcode_library(
name,
srcs = [],
hdrs = [],
file_specific_flags = {},
**kwargs
):
"""Builds a bitcode library
Args:
name: Unique name of the build rule.
srcs: List of source files (*.cl, *.ll).
hdrs: List of header files (*.h).
file_specific_flags: Per-file dict of flags to be passed to clang.
**kwargs: Attributes relevant for a common rule.
"""
clang_tool = "@llvm-project//clang:clang"
clang_include = "@llvm-raw//:clang/lib/Headers"
llvm_link_tool = "@llvm-project//llvm:llvm-link"
opt_tool = "@llvm-project//llvm:opt"
prepare_builtins_tool = ":prepare_builtins"

include_paths = dict([(paths.dirname(h), None) for h in hdrs]).keys()
includes = " ".join(["-I$(location {})".format(inc) for inc in include_paths])
flags = ("-fcolor-diagnostics -Werror -Wno-error=atomic-alignment -x cl -Xclang " +
"-cl-std=CL2.0 --target=amdgcn-amd-amdhsa -fvisibility=hidden -fomit-frame-pointer " +
"-Xclang -finclude-default-header -Xclang -fexperimental-strict-floating-point " +
"-Xclang -fdenormal-fp-math=dynamic -Xclang -Qn " +
"-nogpulib -cl-no-stdinc -Xclang -mcode-object-version=none")

link_inputs = []

for src in srcs:
filename = paths.basename(src)
(basename, _, ext) = filename.partition(".")

if (ext == "ll"):
link_inputs.append(src)
continue

out = basename + ".bc"
link_inputs.append(out)
extra_flags = " ".join(file_specific_flags.get(filename,[]))
native.genrule(
name = "compile_" + basename,
srcs = [src] + hdrs + include_paths,
outs = [out],
#TODO(rocm): Ugly hack to access bultin clang includes.
cmd = "$(location {}) -I$(execpath {}).runfiles/llvm-project/clang/staging/include/ {} {} {} -emit-llvm -c $(location {}) -o $@".format(
clang_tool, clang_tool, includes, flags, extra_flags, src),
tools = [clang_tool],
message = "Compiling {} ...".format(filename),
)

link_message = "Linking {}.bc ...".format(name)

prelink_out = name + ".link0.lib.bc"
native.genrule(
name = "prelink_" + name,
srcs = link_inputs,
outs = [prelink_out],
cmd = "$(location {}) $(SRCS) -o $@".format(llvm_link_tool),
tools = [llvm_link_tool],
message = link_message,
)

internalize_out = name + ".lib.bc"
native.genrule(
name = "internalize_" + name,
srcs = [prelink_out],
outs = [internalize_out],
cmd = "$(location {}) -internalize -only-needed $< -o $@".format(llvm_link_tool),
tools = [llvm_link_tool],
message = link_message,
)

strip_out = name + ".strip.bc"
native.genrule(
name = "strip_" + name,
srcs = [internalize_out],
outs = [strip_out],
cmd = "$(location {}) -passes=amdgpu-unify-metadata,strip -o $@ $<".format(opt_tool),
tools = [opt_tool],
message = link_message,
)

native.genrule(
name = name,
srcs = [strip_out],
outs = [name + ".bc"],
cmd = "$(location {}) -o $@ $<".format(prepare_builtins_tool),
tools = [prepare_builtins_tool],
message = link_message,
)
18 changes: 18 additions & 0 deletions third_party/rocm_device_libs/prepare_builtins.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
diff --git a/utils/prepare-builtins/prepare-builtins.cpp b/utils/prepare-builtins/prepare-builtins.cpp
index 7fc9d06dab7d..2a93638c3f8f 100644
--- a/utils/prepare-builtins/prepare-builtins.cpp
+++ b/utils/prepare-builtins/prepare-builtins.cpp
@@ -73,6 +73,13 @@ int main(int argc, char **argv) {
return 1;
}

+ // Strip the OpenCL version metadata. There are a lot of linked
+ // modules in the library build, each spamming the same
+ // version. This may also report a different version than the user
+ // program is using. This should probably be uniqued when linking.
+ if (NamedMDNode *OCLVersion = M->getNamedMetadata("opencl.ocl.version"))
+ M->eraseNamedMetadata(OCLVersion);
+
// Set linkage of every external definition to linkonce_odr.
for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) {
if (!i->isDeclaration() && i->getLinkage() == GlobalValue::ExternalLinkage) {
62 changes: 62 additions & 0 deletions third_party/rocm_device_libs/rocm_device_libs.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
load("build_defs.bzl", "bitcode_library")

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

exports_files([
"LICENSE.TXT",
])

cc_binary(
name = "prepare_builtins",
srcs = glob([
"utils/prepare-builtins/*.cpp",
"utils/prepare-builtins/*.h",
]),
copts = [
"-fno-rtti -fno-exceptions",
],
deps = [
"@llvm-project//llvm:BitReader",
"@llvm-project//llvm:BitWriter",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Support",
],
visibility = ["//visibility:private"],
)

bitcode_library(
name = "ocml",
srcs = glob([
"ocml/src/*.cl"
]),
hdrs = glob([
"ocml/src/*.h",
"ocml/inc/*.h",
"irif/inc/*.h",
"oclc/inc/*.h",
]),
file_specific_flags = {
"native_logF.cl": ["-fapprox-func"],
"native_expF.cl": ["-fapprox-func"],
"sqrtF.cl": ["-cl-fp32-correctly-rounded-divide-sqrt"],
},
)

bitcode_library(
name = "ockl",
srcs = glob([
"ockl/src/*.cl",
"ockl/src/*.ll",
]),
hdrs = glob([
"ockl/inc/*.h",
"irif/inc/*.h",
"oclc/inc/*.h",
]),
file_specific_flags = {
"gaaf.cl": ["-munsafe-fp-atomics"],
},
)
22 changes: 22 additions & 0 deletions third_party/rocm_device_libs/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Provides the repository macro to import Rocm-Device-Libs"""

load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
"""Imports Rocm-Device-Libs."""
LLVM_COMMIT = "0cf1859d038376421b4cd597e3df90d37cfca06e"
LLVM_SHA256 = "0374d1efa0f049d2d1c24c4d86029b006cb5594cc0a1b6a18c49fb094c29cd29"

tf_http_archive(
name = "rocm_device_libs",
sha256 = LLVM_SHA256,
strip_prefix = "llvm-project-{commit}/amd/device-libs".format(commit = LLVM_COMMIT),
urls = tf_mirror_urls("https://github.com/ROCm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)),
build_file = "//third_party/rocm_device_libs:rocm_device_libs.BUILD",
patch_file = [
"//third_party/rocm_device_libs:prepare_builtins.patch",
],
link_files = {
"//third_party/rocm_device_libs:build_defs.bzl": "build_defs.bzl",
},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
From d539916e4d49cca93f54a5f99f7822050205432c Mon Sep 17 00:00:00 2001
From: Jungwook Park <[email protected]>
Date: Thu, 7 Aug 2025 06:34:49 -0500
Subject: [PATCH] [AMD] Quick fix disabling transposed load used as different
type.

Disabling transposedLoad if dot is using it as a different element type.
Otherwise it's picking wrong vectorsize when lowering.
---
.../lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp | 20 +++++++++++++++++++
1 file changed, 20 insertions(+)

diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
index 661a17678..6bda3a818 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
@@ -214,6 +214,26 @@ private:
return false;
}

+ // transposed load can be used only when it's consumed by dot with the
+ // loaded data type.
+ int opIdx = 0;
+ triton::gpu::LocalLoadOp lLoad = cast<triton::gpu::LocalLoadOp>(localLoad);
+ if (auto dotEnc = lLoad.getSrc().getType().getEncoding())
+ opIdx = cast<triton::gpu::DotOperandEncodingAttr>(dotEnc).getOpIdx();
+ else
+ return false;
+
+ SetVector<Operation *> slice;
+ getForwardSlice(localLoad, &slice);
+ for (auto op : slice) {
+ if (auto dotOp = dyn_cast<triton::DotOp>(op)) {
+ auto inputMat = (opIdx == 0) ? dotOp.getA() : dotOp.getB();
+ auto bitwidthMat = inputMat.getType().getElementTypeBitWidth();
+ if (bitwidth != bitwidthMat)
+ return false;
+ }
+ }
+
return true;
}

--
2.34.1

14 changes: 14 additions & 0 deletions third_party/triton/temporary/accelerateamdmatmul.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -1068,7 +1068,10 @@ public:
if (isFloat(srcElTy) && isFloat(dstElTy)) {
auto rmode =
RoundingModeAttr::get(rewriter.getContext(), RoundingMode::RTNE);
- return rewriter.create<FpToFpOp>(loc, dstTy, v, rmode);
+ if (dstElTy.getIntOrFloatBitWidth() < srcElTy.getIntOrFloatBitWidth()) {
+ return rewriter.create<FpToFpOp>(loc, dstTy, v, rmode);
+ }
+ return rewriter.create<FpToFpOp>(loc, dstTy, v);
}
if (!isFloat(srcElTy) && isFloat(dstElTy))
return rewriter.create<arith::SIToFPOp>(loc, dstTy, v);
Loading
Loading