Skip to content

Commit 6743efd

Browse files
zoranjovanovic-ns authored and zahiqbal committed
misc fixes ported from rocm-jaxlib-v0.6.0
--------- Co-authored-by: Pavel Emeliyanenko <[email protected]> (cherry picked from commit f013645) (cherry picked from commit b03cd94) Added support for waves_per_eu function attribute. (#181) (cherry picked from commit bc1d816) (cherry picked from commit d3f94e9) removed two line change (revert of half of the openxla#25959 commit (cherry picked from commit 109e138) Fixes for jax 0.6.0 (#207) * Add fixes for jax plugin 0.6.0 Drop NEEDED linking to unnecessary libs. These are loaded by amdhipruntime and not us. Fix missing NEEDED on MIOpen shared object. * Minor rocblas related changes for rocm 70 (cherry picked from commit 0de7d49) --------- Co-authored-by: Zoran Jovanovic <[email protected]> (cherry picked from commit 28f10a0) Add hipBLASLt support for gfx11. (#301) (cherry picked from commit f814bff) Add bf16 starting from gfx11, bugfix & optimize RocmComputeCapability (#303) * Bugfix and improve device_description.h::RocmComputeCompatibility * Enable ALG_DOT_BF16* on rocm with HW support (cherry picked from commit 510ea06) [ROCm] Use bundled bitcode files (#196) Also trim bitcode file list to ockl.bc and ocml.bc only. (cherry picked from commit fc9e3c3) Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms (#312) * Add MIOPEN_FIND_ENFORCE For ROCm 7 for convolution gemms * Exclude failing CollectiveOpsE2E tests (cherry picked from commit fb6ddfb) Restore RocmComputeCapability:: gfx11_rx7900() and gfx12_rx8900() methods (#333) At least gfx11_rx7900() is still needed for TF build. 
(cherry picked from commit 13c3de1) Make device_count_ atomic (#343) * Make device_count_ atomic * Use relaxed memory order * Fix build error (cherry picked from commit 8513f2d) fix hardcoded max registers (#345) (cherry picked from commit f3e170a) fix hardcoded ecc enabled (#348) (cherry picked from commit 9cfa74a) remove reserved memory (#349) (cherry picked from commit 0015d0e) Add rocm_dev config for remote caching (#353) (cherry picked from commit c815420) added rocm7 support to EnablePeerAccess (#347) * added rocm7 support to EnablePeerAccess * use wrap namespace, clang-format and add comments (cherry picked from commit 85548a7) [ROCm] Disable Cudnn fusions (#358) (cherry picked from commit edab8b2)
1 parent 12a7305 commit 6743efd

36 files changed

+613
-188
lines changed

build_tools/rocm/run_xla_multi_gpu.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,4 +113,4 @@ bazel --bazelrc=build_tools/rocm/rocm_xla.bazelrc test \
113113
# clean up bazel disk_cache
114114
bazel shutdown \
115115
--disk_cache=${BAZEL_DISK_CACHE_DIR} \
116-
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
116+
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}

tensorflow.bazelrc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,19 @@ build:rocm_clang_official --action_env=TF_ROCM_CLANG="1"
275275
build:rocm_clang_official --linkopt="-fuse-ld=lld"
276276
build:rocm_clang_official --host_linkopt="-fuse-ld=lld"
277277

278+
build:rocm_dev --remote_upload_local_results=false
279+
build:rocm_dev --remote_cache="https://wardite.cluster.engflow.com"
280+
281+
build:rocm_rbe --bes_backend="grpcs://wardite.cluster.engflow.com"
282+
build:rocm_rbe --bes_results_url="https://wardite.cluster.engflow.com/invocation/"
283+
build:rocm_rbe --remote_executor="grpcs://wardite.cluster.engflow.com"
284+
build:rocm_rbe --host_platform="//platform/linux_x64"
285+
build:rocm_rbe --extra_execution_platforms="//platform/linux_x64"
286+
build:rocm_rbe --platforms="//platform/linux_x64"
287+
build:rocm_rbe --bes_timeout=600s
288+
build:rocm_rbe --tls_client_certificate="/tf/certificates/ci-cert.crt"
289+
build:rocm_rbe --tls_client_key="/tf/certificates/ci-cert.key"
290+
278291
build:rocm_ci --config=rocm_clang_official
279292

280293
build:rocm_ci_hermetic --config=rocm_clang_official

third_party/gpus/rocm/BUILD.tpl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,11 @@ cc_library(
171171
],
172172
strip_include_prefix = "%{rocm_root}",
173173
deps = [
174-
":amd_comgr",
175-
":hsa_rocr",
176174
":rocm_config",
177-
":rocm_smi",
178175
":rocprofiler_register",
179176
":system_libs",
180177
],
178+
visibility = ["//visibility:public"],
181179
)
182180

183181
# Used by jax_rocm_plugin to minimally link to hip runtime.
@@ -259,8 +257,8 @@ cc_library(
259257
cc_library(
260258
name = "miopen",
261259
hdrs = glob(["%{rocm_root}/include/miopen/**"]),
260+
srcs = glob(["%{rocm_root}/lib/libMIOpen*.so*"]),
262261
data = glob([
263-
"%{rocm_root}/lib/libMIOpen*.so*",
264262
"%{rocm_root}/share/miopen/**",
265263
]),
266264
include_prefix = "rocm",

third_party/rocm_device_libs/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# copybara:uncomment package(default_applicable_licenses = ["//third_party/tensorflow:license"])
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
load("@bazel_skylib//lib:paths.bzl", "paths")
2+
3+
def bitcode_library(
4+
name,
5+
srcs = [],
6+
hdrs = [],
7+
file_specific_flags = {},
8+
**kwargs
9+
):
10+
"""Builds a bitcode library
11+
12+
Args:
13+
name: Unique name of the build rule.
14+
srcs: List of source files (*.cl, *.ll).
15+
hdrs: List of header files (*.h).
16+
file_specific_flags: Per-file dict of flags to be passed to clang.
17+
**kwargs: Attributes relevant for a common rule.
18+
"""
19+
clang_tool = "@llvm-project//clang:clang"
20+
clang_include = "@llvm-raw//:clang/lib/Headers"
21+
llvm_link_tool = "@llvm-project//llvm:llvm-link"
22+
opt_tool = "@llvm-project//llvm:opt"
23+
prepare_builtins_tool = ":prepare_builtins"
24+
25+
include_paths = dict([(paths.dirname(h), None) for h in hdrs]).keys()
26+
includes = " ".join(["-I$(location {})".format(inc) for inc in include_paths])
27+
flags = ("-fcolor-diagnostics -Werror -Wno-error=atomic-alignment -x cl -Xclang " +
28+
"-cl-std=CL2.0 --target=amdgcn-amd-amdhsa -fvisibility=hidden -fomit-frame-pointer " +
29+
"-Xclang -finclude-default-header -Xclang -fexperimental-strict-floating-point " +
30+
"-Xclang -fdenormal-fp-math=dynamic -Xclang -Qn " +
31+
"-nogpulib -cl-no-stdinc -Xclang -mcode-object-version=none")
32+
33+
link_inputs = []
34+
35+
for src in srcs:
36+
filename = paths.basename(src)
37+
(basename, _, ext) = filename.partition(".")
38+
39+
if (ext == "ll"):
40+
link_inputs.append(src)
41+
continue
42+
43+
out = basename + ".bc"
44+
link_inputs.append(out)
45+
extra_flags = " ".join(file_specific_flags.get(filename,[]))
46+
native.genrule(
47+
name = "compile_" + basename,
48+
srcs = [src] + hdrs + include_paths,
49+
outs = [out],
50+
# TODO(rocm): Ugly hack to access builtin clang includes.
51+
cmd = "$(location {}) -I$(execpath {}).runfiles/llvm-project/clang/staging/include/ {} {} {} -emit-llvm -c $(location {}) -o $@".format(
52+
clang_tool, clang_tool, includes, flags, extra_flags, src),
53+
tools = [clang_tool],
54+
message = "Compiling {} ...".format(filename),
55+
)
56+
57+
link_message = "Linking {}.bc ...".format(name)
58+
59+
prelink_out = name + ".link0.lib.bc"
60+
native.genrule(
61+
name = "prelink_" + name,
62+
srcs = link_inputs,
63+
outs = [prelink_out],
64+
cmd = "$(location {}) $(SRCS) -o $@".format(llvm_link_tool),
65+
tools = [llvm_link_tool],
66+
message = link_message,
67+
)
68+
69+
internalize_out = name + ".lib.bc"
70+
native.genrule(
71+
name = "internalize_" + name,
72+
srcs = [prelink_out],
73+
outs = [internalize_out],
74+
cmd = "$(location {}) -internalize -only-needed $< -o $@".format(llvm_link_tool),
75+
tools = [llvm_link_tool],
76+
message = link_message,
77+
)
78+
79+
strip_out = name + ".strip.bc"
80+
native.genrule(
81+
name = "strip_" + name,
82+
srcs = [internalize_out],
83+
outs = [strip_out],
84+
cmd = "$(location {}) -passes=amdgpu-unify-metadata,strip -o $@ $<".format(opt_tool),
85+
tools = [opt_tool],
86+
message = link_message,
87+
)
88+
89+
native.genrule(
90+
name = name,
91+
srcs = [strip_out],
92+
outs = [name + ".bc"],
93+
cmd = "$(location {}) -o $@ $<".format(prepare_builtins_tool),
94+
tools = [prepare_builtins_tool],
95+
message = link_message,
96+
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
diff --git a/utils/prepare-builtins/prepare-builtins.cpp b/utils/prepare-builtins/prepare-builtins.cpp
2+
index 7fc9d06dab7d..2a93638c3f8f 100644
3+
--- a/utils/prepare-builtins/prepare-builtins.cpp
4+
+++ b/utils/prepare-builtins/prepare-builtins.cpp
5+
@@ -73,6 +73,13 @@ int main(int argc, char **argv) {
6+
return 1;
7+
}
8+
9+
+ // Strip the OpenCL version metadata. There are a lot of linked
10+
+ // modules in the library build, each spamming the same
11+
+ // version. This may also report a different version than the user
12+
+ // program is using. This should probably be uniqued when linking.
13+
+ if (NamedMDNode *OCLVersion = M->getNamedMetadata("opencl.ocl.version"))
14+
+ M->eraseNamedMetadata(OCLVersion);
15+
+
16+
// Set linkage of every external definition to linkonce_odr.
17+
for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) {
18+
if (!i->isDeclaration() && i->getLinkage() == GlobalValue::ExternalLinkage) {
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
load("build_defs.bzl", "bitcode_library")
2+
3+
licenses(["notice"])
4+
5+
package(default_visibility = ["//visibility:public"])
6+
7+
exports_files([
8+
"LICENSE.TXT",
9+
])
10+
11+
cc_binary(
12+
name = "prepare_builtins",
13+
srcs = glob([
14+
"utils/prepare-builtins/*.cpp",
15+
"utils/prepare-builtins/*.h",
16+
]),
17+
copts = [
18+
"-fno-rtti -fno-exceptions",
19+
],
20+
deps = [
21+
"@llvm-project//llvm:BitReader",
22+
"@llvm-project//llvm:BitWriter",
23+
"@llvm-project//llvm:Core",
24+
"@llvm-project//llvm:IRReader",
25+
"@llvm-project//llvm:Support",
26+
],
27+
visibility = ["//visibility:private"],
28+
)
29+
30+
bitcode_library(
31+
name = "ocml",
32+
srcs = glob([
33+
"ocml/src/*.cl"
34+
]),
35+
hdrs = glob([
36+
"ocml/src/*.h",
37+
"ocml/inc/*.h",
38+
"irif/inc/*.h",
39+
"oclc/inc/*.h",
40+
]),
41+
file_specific_flags = {
42+
"native_logF.cl": ["-fapprox-func"],
43+
"native_expF.cl": ["-fapprox-func"],
44+
"sqrtF.cl": ["-cl-fp32-correctly-rounded-divide-sqrt"],
45+
},
46+
)
47+
48+
bitcode_library(
49+
name = "ockl",
50+
srcs = glob([
51+
"ockl/src/*.cl",
52+
"ockl/src/*.ll",
53+
]),
54+
hdrs = glob([
55+
"ockl/inc/*.h",
56+
"irif/inc/*.h",
57+
"oclc/inc/*.h",
58+
]),
59+
file_specific_flags = {
60+
"gaaf.cl": ["-munsafe-fp-atomics"],
61+
},
62+
)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""Provides the repository macro to import Rocm-Device-Libs"""
2+
3+
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
4+
5+
def repo():
6+
"""Imports Rocm-Device-Libs."""
7+
LLVM_COMMIT = "0cf1859d038376421b4cd597e3df90d37cfca06e"
8+
LLVM_SHA256 = "0374d1efa0f049d2d1c24c4d86029b006cb5594cc0a1b6a18c49fb094c29cd29"
9+
10+
tf_http_archive(
11+
name = "rocm_device_libs",
12+
sha256 = LLVM_SHA256,
13+
strip_prefix = "llvm-project-{commit}/amd/device-libs".format(commit = LLVM_COMMIT),
14+
urls = tf_mirror_urls("https://github.com/ROCm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)),
15+
build_file = "//third_party/rocm_device_libs:rocm_device_libs.BUILD",
16+
patch_file = [
17+
"//third_party/rocm_device_libs:prepare_builtins.patch",
18+
],
19+
link_files = {
20+
"//third_party/rocm_device_libs:build_defs.bzl": "build_defs.bzl",
21+
},
22+
)

workspace2.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ load("//third_party/py:python_configure.bzl", "python_configure")
3030
load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo")
3131
load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo")
3232
load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
33+
load("//third_party/rocm_device_libs:workspace.bzl", rocm_device_libs = "repo")
3334
load("//third_party/robin_map:workspace.bzl", robin_map = "repo")
3435
load("//third_party/shardy:workspace.bzl", shardy = "repo")
3536
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
@@ -66,6 +67,7 @@ def _initialize_third_party():
6667
nvshmem()
6768
pybind11_abseil()
6869
pybind11_bazel()
70+
rocm_device_libs()
6971
robin_map()
7072
shardy()
7173
stablehlo()

xla/backends/gpu/codegen/emitters/reduction.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class ReductionFusion : public EmitterBase {
121121
return IndexingMap::GetUndefined();
122122
}
123123

124-
int64_t WarpSize() const {
124+
virtual int64_t WarpSize() const {
125125
return ::xla::gpu::WarpSize(analysis_.device_info());
126126
}
127127

@@ -198,6 +198,11 @@ class ColumnReductionFusion : public ReductionFusion {
198198
public:
199199
explicit ColumnReductionFusion(const HloFusionAnalysis& analysis);
200200

201+
int64_t WarpSize() const override {
202+
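// HACK(PAE): hard-codes a warp size of 32 for column reductions, overriding the device-derived value returned by ReductionFusion::WarpSize().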
203+
return 32;
204+
}
205+
201206
protected:
202207
llvm::SmallVector<mlir::Value> EmitReduction(
203208
int group_id, EmitterState& state) const override;
@@ -218,6 +223,11 @@ class SmallColumnReductionFusion : public ReductionFusion {
218223
public:
219224
explicit SmallColumnReductionFusion(const HloFusionAnalysis& analysis);
220225

226+
int64_t WarpSize() const override {
227+
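// HACK(PAE): hard-codes a warp size of 32 for small column reductions, overriding the device-derived value returned by ReductionFusion::WarpSize().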
228+
return 32;
229+
}
230+
221231
protected:
222232
llvm::SmallVector<mlir::Value> EmitReduction(
223233
int group_id, EmitterState& state) const override;

0 commit comments

Comments
 (0)