Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_tools/lint/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"multi_gpu_h100": (
"Used by `xla_test` to signal that multiple H100s are needed."
),
"skip_rocprofiler_sdk": "used to skip rocmtracer test as it calls rocprofiler-sdk via rocprofiler_force_configure",
}


Expand Down
3 changes: 2 additions & 1 deletion build_tools/rocm/run_xla.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ BasicDotAlgorithmEmitterTestSuite/BasicDotAlgorithmEmitterTest.BasicAlgorithmIsE
CommandBufferTests/CommandBufferTest.IndexConditional/*
CommandBufferTests/CommandBufferTest.WhileLoop/*
CommandBufferTests/CommandBufferTest.TrueFalseConditional/*
BufferComparatorTest.VeryLargeArray_Device_U8_Aligned
)

BAZEL_DISK_CACHE_SIZE=100G
Expand Down Expand Up @@ -147,4 +148,4 @@ bazel --bazelrc=build_tools/rocm/rocm_xla.bazelrc test \
# clean up bazel disk_cache
bazel shutdown \
--disk_cache=${BAZEL_DISK_CACHE_DIR} \
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
2 changes: 1 addition & 1 deletion build_tools/rocm/run_xla_multi_gpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,4 @@ bazel --bazelrc=build_tools/rocm/rocm_xla.bazelrc test \
# clean up bazel disk_cache
bazel shutdown \
--disk_cache=${BAZEL_DISK_CACHE_DIR} \
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
13 changes: 13 additions & 0 deletions tensorflow.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,19 @@ build:rocm_clang_official --action_env=TF_ROCM_CLANG="1"
build:rocm_clang_official --linkopt="-fuse-ld=lld"
build:rocm_clang_official --host_linkopt="-fuse-ld=lld"

build:rocm_dev --remote_upload_local_results=false
build:rocm_dev --remote_cache="https://wardite.cluster.engflow.com"

build:rocm_rbe --bes_backend="grpcs://wardite.cluster.engflow.com"
build:rocm_rbe --bes_results_url="https://wardite.cluster.engflow.com/invocation/"
build:rocm_rbe --remote_executor="grpcs://wardite.cluster.engflow.com"
build:rocm_rbe --host_platform="//platform/linux_x64"
build:rocm_rbe --extra_execution_platforms="//platform/linux_x64"
build:rocm_rbe --platforms="//platform/linux_x64"
build:rocm_rbe --bes_timeout=600s
build:rocm_rbe --tls_client_certificate="/tf/certificates/ci-cert.crt"
build:rocm_rbe --tls_client_key="/tf/certificates/ci-cert.key"

build:rocm_ci --config=rocm_clang_official

build:rocm_ci_hermetic --config=rocm_clang_official
Expand Down
19 changes: 15 additions & 4 deletions third_party/gpus/rocm/BUILD.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,11 @@ cc_library(
],
strip_include_prefix = "%{rocm_root}",
deps = [
":amd_comgr",
":hsa_rocr",
":rocm_config",
":rocm_smi",
":rocprofiler_register",
":system_libs",
],
visibility = ["//visibility:public"],
)

# Used by jax_rocm_plugin to minimally link to hip runtime.
Expand Down Expand Up @@ -259,8 +257,8 @@ cc_library(
cc_library(
name = "miopen",
hdrs = glob(["%{rocm_root}/include/miopen/**"]),
srcs = glob(["%{rocm_root}/lib/libMIOpen*.so*"]),
data = glob([
"%{rocm_root}/lib/libMIOpen*.so*",
"%{rocm_root}/share/miopen/**",
]),
include_prefix = "rocm",
Expand Down Expand Up @@ -349,6 +347,19 @@ cc_library(
deps = [":rocm_config"],
)

# ROCm profiler SDK: the librocprofiler-sdk shared libraries together with the
# rocprofiler-sdk headers, which are exposed under the `rocm/` include prefix.
cc_library(
    name = "rocprofiler-sdk",
    srcs = glob(["%{rocm_root}/lib/librocprofiler-sdk*.so*"]),
    hdrs = glob(["%{rocm_root}/include/rocprofiler-sdk/**"]),
    include_prefix = "rocm",
    includes = ["%{rocm_root}/include/"],
    strip_include_prefix = "%{rocm_root}",
    visibility = ["//visibility:public"],
    deps = [":rocm_config"],
)

cc_library(
name = "rocsolver",
srcs = glob(["%{rocm_root}/lib/librocsolver*.so*"]),
Expand Down
1 change: 1 addition & 0 deletions third_party/gpus/rocm_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def _find_libs(repository_ctx, rocm_config, miopen_path, rccl_path, bash_bin):
("rocsolver", rocm_config.rocm_toolkit_path),
("hipfft", rocm_config.rocm_toolkit_path),
("rocrand", rocm_config.rocm_toolkit_path),
("rocprofiler-sdk", rocm_config.rocm_toolkit_path),
]
]
if int(rocm_config.rocm_version_number) >= 40500:
Expand Down
1 change: 1 addition & 0 deletions third_party/rocm_device_libs/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# copybara:uncomment package(default_applicable_licenses = ["//third_party/tensorflow:license"])
96 changes: 96 additions & 0 deletions third_party/rocm_device_libs/build_defs.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
load("@bazel_skylib//lib:paths.bzl", "paths")

def bitcode_library(
        name,
        srcs = [],
        hdrs = [],
        file_specific_flags = {},
        **kwargs):
    """Builds an AMDGPU LLVM bitcode library from OpenCL / LLVM IR sources.

    Every source that is not a `.ll` file is compiled to bitcode with clang.
    The resulting `.bc` files and the `.ll` sources are then linked with
    llvm-link, internalized, run through `opt`
    (`amdgpu-unify-metadata,strip`) and finally post-processed by the
    `:prepare_builtins` tool. The final artifact is `<name>.bc`, produced by
    the target called `name`.

    Args:
      name: Unique name of the build rule.
      srcs: List of source files (*.cl, *.ll).
      hdrs: List of header files (*.h). Their directories are added to the
        clang include path.
      file_specific_flags: Dict mapping a source file name (basename, e.g.
        "foo.cl") to a list of extra flags passed to clang for that file.
      **kwargs: Attributes relevant for a common rule (e.g. `visibility`,
        `tags`). They are forwarded to the final `name` target.
    """
    clang_tool = "@llvm-project//clang:clang"
    llvm_link_tool = "@llvm-project//llvm:llvm-link"
    opt_tool = "@llvm-project//llvm:opt"
    prepare_builtins_tool = ":prepare_builtins"

    # Unique header directories, in first-seen order.
    include_paths = {paths.dirname(h): None for h in hdrs}.keys()
    includes = " ".join(["-I$(location {})".format(inc) for inc in include_paths])
    flags = ("-fcolor-diagnostics -Werror -Wno-error=atomic-alignment -x cl -Xclang " +
             "-cl-std=CL2.0 --target=amdgcn-amd-amdhsa -fvisibility=hidden -fomit-frame-pointer " +
             "-Xclang -finclude-default-header -Xclang -fexperimental-strict-floating-point " +
             "-Xclang -fdenormal-fp-math=dynamic -Xclang -Qn " +
             "-nogpulib -cl-no-stdinc -Xclang -mcode-object-version=none")

    link_inputs = []

    for src in srcs:
        filename = paths.basename(src)
        (basename, _, ext) = filename.partition(".")

        # LLVM IR sources need no compilation; llvm-link consumes them directly.
        if ext == "ll":
            link_inputs.append(src)
            continue

        # NOTE(review): the target and output names depend only on the file
        # stem, so two sources with the same stem in one package (even across
        # different bitcode_library calls) would clash - confirm this cannot
        # happen for the intended inputs.
        out = basename + ".bc"
        link_inputs.append(out)
        extra_flags = " ".join(file_specific_flags.get(filename, []))
        native.genrule(
            name = "compile_" + basename,
            srcs = [src] + hdrs + include_paths,
            outs = [out],
            # TODO(rocm): Ugly hack to access builtin clang includes.
            cmd = "$(location {}) -I$(execpath {}).runfiles/llvm-project/clang/staging/include/ {} {} {} -emit-llvm -c $(location {}) -o $@".format(
                clang_tool,
                clang_tool,
                includes,
                flags,
                extra_flags,
                src,
            ),
            tools = [clang_tool],
            message = "Compiling {} ...".format(filename),
        )

    # Step 1: link all compiled objects and .ll sources into one module.
    prelink_out = name + ".link0.lib.bc"
    native.genrule(
        name = "prelink_" + name,
        srcs = link_inputs,
        outs = [prelink_out],
        cmd = "$(location {}) $(SRCS) -o $@".format(llvm_link_tool),
        tools = [llvm_link_tool],
        message = "Linking {}.bc ...".format(name),
    )

    # Step 2: internalize symbols (llvm-link -internalize -only-needed).
    internalize_out = name + ".lib.bc"
    native.genrule(
        name = "internalize_" + name,
        srcs = [prelink_out],
        outs = [internalize_out],
        cmd = "$(location {}) -internalize -only-needed $< -o $@".format(llvm_link_tool),
        tools = [llvm_link_tool],
        message = "Internalizing {}.bc ...".format(name),
    )

    # Step 3: unify AMDGPU metadata and strip the module.
    strip_out = name + ".strip.bc"
    native.genrule(
        name = "strip_" + name,
        srcs = [internalize_out],
        outs = [strip_out],
        cmd = "$(location {}) -passes=amdgpu-unify-metadata,strip -o $@ $<".format(opt_tool),
        tools = [opt_tool],
        message = "Stripping {}.bc ...".format(name),
    )

    # Step 4: final post-processing; this is the public target, so it is the
    # one that receives the caller-supplied common attributes.
    native.genrule(
        name = name,
        srcs = [strip_out],
        outs = [name + ".bc"],
        cmd = "$(location {}) -o $@ $<".format(prepare_builtins_tool),
        tools = [prepare_builtins_tool],
        message = "Preparing builtins for {}.bc ...".format(name),
        **kwargs
    )
18 changes: 18 additions & 0 deletions third_party/rocm_device_libs/prepare_builtins.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
diff --git a/utils/prepare-builtins/prepare-builtins.cpp b/utils/prepare-builtins/prepare-builtins.cpp
index 7fc9d06dab7d..2a93638c3f8f 100644
--- a/utils/prepare-builtins/prepare-builtins.cpp
+++ b/utils/prepare-builtins/prepare-builtins.cpp
@@ -73,6 +73,13 @@ int main(int argc, char **argv) {
return 1;
}

+ // Strip the OpenCL version metadata. There are a lot of linked
+ // modules in the library build, each spamming the same
+ // version. This may also report a different version than the user
+ // program is using. This should probably be uniqued when linking.
+ if (NamedMDNode *OCLVersion = M->getNamedMetadata("opencl.ocl.version"))
+ M->eraseNamedMetadata(OCLVersion);
+
// Set linkage of every external definition to linkonce_odr.
for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) {
if (!i->isDeclaration() && i->getLinkage() == GlobalValue::ExternalLinkage) {
62 changes: 62 additions & 0 deletions third_party/rocm_device_libs/rocm_device_libs.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
load("build_defs.bzl", "bitcode_library")

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

exports_files([
"LICENSE.TXT",
])

# Host tool built from the device-libs sources under utils/prepare-builtins.
# It is private to this package; the `bitcode_library` macro references it as
# ":prepare_builtins" for the last step of each bitcode library.
cc_binary(
    name = "prepare_builtins",
    srcs = glob([
        "utils/prepare-builtins/*.cpp",
        "utils/prepare-builtins/*.h",
    ]),
    # One flag per list entry, so the options do not rely on Bazel's shell
    # tokenization of copts strings (which the `no_copts_tokenization`
    # feature disables).
    copts = [
        "-fno-rtti",
        "-fno-exceptions",
    ],
    visibility = ["//visibility:private"],
    deps = [
        "@llvm-project//llvm:BitReader",
        "@llvm-project//llvm:BitWriter",
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:IRReader",
        "@llvm-project//llvm:Support",
    ],
)

# OCML bitcode library (`ocml.bc`), built from the OpenCL sources in ocml/src.
bitcode_library(
    name = "ocml",
    srcs = glob(["ocml/src/*.cl"]),
    hdrs = glob([
        "ocml/src/*.h",
        "ocml/inc/*.h",
        "irif/inc/*.h",
        "oclc/inc/*.h",
    ]),
    # Extra clang flags applied only to the named source files.
    file_specific_flags = {
        "native_logF.cl": ["-fapprox-func"],
        "native_expF.cl": ["-fapprox-func"],
        "sqrtF.cl": ["-cl-fp32-correctly-rounded-divide-sqrt"],
    },
)

# OCKL bitcode library (`ockl.bc`), built from the OpenCL and LLVM IR sources
# in ockl/src.
bitcode_library(
    name = "ockl",
    srcs = glob([
        "ockl/src/*.cl",
        "ockl/src/*.ll",
    ]),
    hdrs = glob([
        "ockl/inc/*.h",
        "irif/inc/*.h",
        "oclc/inc/*.h",
    ]),
    # Extra clang flags applied only to the named source files.
    file_specific_flags = {
        "gaaf.cl": ["-munsafe-fp-atomics"],
    },
)
22 changes: 22 additions & 0 deletions third_party/rocm_device_libs/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Provides the repository macro to import Rocm-Device-Libs"""

load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
    """Registers the `rocm_device_libs` external repository.

    The sources are the `amd/device-libs` subtree of a pinned ROCm
    llvm-project archive, with a local BUILD file, one patch and a linked-in
    `build_defs.bzl`.
    """

    # Pinned ROCm/llvm-project revision and the checksum of its archive.
    LLVM_COMMIT = "0cf1859d038376421b4cd597e3df90d37cfca06e"
    LLVM_SHA256 = "0374d1efa0f049d2d1c24c4d86029b006cb5594cc0a1b6a18c49fb094c29cd29"

    archive_url = "https://github.com/ROCm/llvm-project/archive/{commit}.tar.gz".format(
        commit = LLVM_COMMIT,
    )
    device_libs_prefix = "llvm-project-{commit}/amd/device-libs".format(
        commit = LLVM_COMMIT,
    )

    tf_http_archive(
        name = "rocm_device_libs",
        sha256 = LLVM_SHA256,
        strip_prefix = device_libs_prefix,
        urls = tf_mirror_urls(archive_url),
        build_file = "//third_party/rocm_device_libs:rocm_device_libs.BUILD",
        patch_file = ["//third_party/rocm_device_libs:prepare_builtins.patch"],
        link_files = {
            "//third_party/rocm_device_libs:build_defs.bzl": "build_defs.bzl",
        },
    )
2 changes: 2 additions & 0 deletions workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ load("//third_party/py:python_configure.bzl", "python_configure")
load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo")
load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo")
load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
load("//third_party/rocm_device_libs:workspace.bzl", rocm_device_libs = "repo")
load("//third_party/robin_map:workspace.bzl", robin_map = "repo")
load("//third_party/shardy:workspace.bzl", shardy = "repo")
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
Expand Down Expand Up @@ -66,6 +67,7 @@ def _initialize_third_party():
nvshmem()
pybind11_abseil()
pybind11_bazel()
rocm_device_libs()
robin_map()
shardy()
stablehlo()
Expand Down
12 changes: 11 additions & 1 deletion xla/backends/gpu/codegen/emitters/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class ReductionFusion : public EmitterBase {
return IndexingMap::GetUndefined();
}

int64_t WarpSize() const {
virtual int64_t WarpSize() const {
return ::xla::gpu::WarpSize(analysis_.device_info());
}

Expand Down Expand Up @@ -198,6 +198,11 @@ class ColumnReductionFusion : public ReductionFusion {
public:
explicit ColumnReductionFusion(const HloFusionAnalysis& analysis);

// Overrides the base-class warp size, which is derived from the device info,
// with a fixed value of 32 for column reductions.
int64_t WarpSize() const override {
// HACK: hard-coded instead of queried from the device.
// NOTE(review): presumably a workaround for devices whose native warp size
// is not 32 (e.g. 64-wide wavefronts) - confirm the reason and replace with
// a principled fix.
return 32;
}

protected:
llvm::SmallVector<mlir::Value> EmitReduction(
int group_id, EmitterState& state) const override;
Expand All @@ -218,6 +223,11 @@ class SmallColumnReductionFusion : public ReductionFusion {
public:
explicit SmallColumnReductionFusion(const HloFusionAnalysis& analysis);

// Overrides the base-class warp size, which is derived from the device info,
// with a fixed value of 32 for small column reductions.
int64_t WarpSize() const override {
// HACK: hard-coded instead of queried from the device.
// NOTE(review): presumably a workaround for devices whose native warp size
// is not 32 (e.g. 64-wide wavefronts) - confirm the reason and replace with
// a principled fix.
return 32;
}

protected:
llvm::SmallVector<mlir::Value> EmitReduction(
int group_id, EmitterState& state) const override;
Expand Down
7 changes: 4 additions & 3 deletions xla/backends/gpu/codegen/emitters/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,14 @@ using mlir::VectorType;
using mlir::func::FuncOp;
using mlir::func::ReturnOp;


namespace mt = ::mlir::tensor;
namespace mv = ::mlir::vector;

constexpr int kTileSize = 32;
constexpr int kNumRows = 4;
constexpr int kNumThreadsPerBlock = 128;
constexpr int kMaxVectorizedBytes = 4;
constexpr int kNumRows = 8;
constexpr int kNumThreadsPerBlock = kNumRows * kTileSize;
constexpr int kMaxVectorizedBytes = 16;

// Reads the 2D vector tile <vector_size x vector_size> from the shared memory
// at the given indices.
Expand Down
6 changes: 1 addition & 5 deletions xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,7 @@ absl::Status CreateTritonPipeline(

std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
const se::DeviceDescription& device_info) {
std::string libdevice_dir = tsl::RocdlRoot();
auto compute_capability = device_info.rocm_compute_capability();
const std::string libdevice_path =
amdgpu::LibDevicePath(compute_capability.gcn_arch_name(), libdevice_dir);
return libdevice_path;
return "__builtin__";
}

} // namespace gpu
Expand Down
Loading