Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_tools/lint/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"multi_gpu_h100": (
"Used by `xla_test` to signal that multiple H100s are needed."
),
"skip_rocprofiler_sdk": "used to skip rocmtracer test as it calls rocprofiler-sdk via rocprofiler_force_configure",
}


Expand Down
3 changes: 2 additions & 1 deletion build_tools/rocm/run_xla.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ BasicDotAlgorithmEmitterTestSuite/BasicDotAlgorithmEmitterTest.BasicAlgorithmIsE
CommandBufferTests/CommandBufferTest.IndexConditional/*
CommandBufferTests/CommandBufferTest.WhileLoop/*
CommandBufferTests/CommandBufferTest.TrueFalseConditional/*
BufferComparatorTest.VeryLargeArray_Device_U8_Aligned
)

BAZEL_DISK_CACHE_SIZE=100G
Expand Down Expand Up @@ -147,4 +148,4 @@ bazel --bazelrc=build_tools/rocm/rocm_xla.bazelrc test \
# clean up bazel disk_cache
bazel shutdown \
--disk_cache=${BAZEL_DISK_CACHE_DIR} \
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
--experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE}
13 changes: 13 additions & 0 deletions third_party/gpus/rocm/BUILD.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,19 @@ cc_library(
deps = [":rocm_config"],
)

# Wraps the rocprofiler-sdk shared libraries and headers found under
# %{rocm_root} as a Bazel cc_library.
cc_library(
name = "rocprofiler-sdk",
# Picks up every librocprofiler-sdk*.so* file in the ROCm lib directory.
srcs = glob(["%{rocm_root}/lib/librocprofiler-sdk*.so*"]),
hdrs = glob(["%{rocm_root}/include/rocprofiler-sdk/**"]),
# Together with strip_include_prefix below, this replaces the %{rocm_root}
# path prefix of each header with "rocm".
include_prefix = "rocm",
# Also puts the ROCm include directory itself on the include search path.
includes = [
"%{rocm_root}/include/",
],
strip_include_prefix = "%{rocm_root}",
visibility = ["//visibility:public"],
deps = [":rocm_config"],
)

cc_library(
name = "rocsolver",
srcs = glob(["%{rocm_root}/lib/librocsolver*.so*"]),
Expand Down
1 change: 1 addition & 0 deletions third_party/gpus/rocm_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def _find_libs(repository_ctx, rocm_config, miopen_path, rccl_path, bash_bin):
("rocsolver", rocm_config.rocm_toolkit_path),
("hipfft", rocm_config.rocm_toolkit_path),
("rocrand", rocm_config.rocm_toolkit_path),
("rocprofiler-sdk", rocm_config.rocm_toolkit_path),
]
]
if int(rocm_config.rocm_version_number) >= 40500:
Expand Down
80 changes: 80 additions & 0 deletions xla/backends/gpu/runtime/buffer_comparator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ limitations under the License.
#include "tsl/platform/ml_dtypes.h"
#include "tsl/platform/test.h"

#include <hip/hip_runtime_api.h>

namespace xla {
namespace gpu {
namespace {
Expand Down Expand Up @@ -403,6 +405,84 @@ TEST_F(BufferComparatorTest, BF16) {
.value());
}

// ROCm-only: CI-safe very-large compare using *device* memory and U8 type.
// - No HMM / mapped host memory
// - 64-bit indexing stress (N > 2^31)
// - Alignment-friendly aliasing offset
//
// lhs and rhs are two overlapping N-byte views into a single zero-filled
// device allocation, offset from each other by kShiftBytes. The test first
// expects the views to compare equal, then overwrites the last byte of rhs
// (which lies outside lhs) and expects the comparison to report a mismatch.
TEST_F(BufferComparatorTest, VeryLargeArray_Device_U8_Aligned) {
  // One byte per element keeps the allocation close to N bytes.
  constexpr PrimitiveType number_type = U8;
  using NT = primitive_util::PrimitiveTypeToNative<number_type>::type;
  static_assert(sizeof(NT) == 1, "This test expects 8-bit elements.");

  // N just above 2^31 so 32-bit indexing overflows.
  constexpr uint64_t element_count = (1ull << 31) + 4096;  // 2,147,487,744
  // Shift rhs by a cacheline (64 B) to keep accesses aligned even if
  // vectorized.
  constexpr size_t kShiftBytes = 64;

  // Total device allocation: N + shift (just over 2 GiB, i.e. ~2.15 GB).
  constexpr size_t bytes_total =
      static_cast<size_t>(element_count) + kShiftBytes;

  // Create stream.
  auto stream_or = stream_exec_->CreateStream();
  ASSERT_TRUE(stream_or.ok()) << stream_or.status();
  std::unique_ptr<se::Stream> stream = std::move(stream_or.value());

  // Single device allocation of (N + shift) bytes.
  void* dev_ptr = nullptr;
  const hipError_t herr = hipMalloc(&dev_ptr, bytes_total);
  if (herr == hipErrorMemoryAllocation) {
    GTEST_SKIP() << "Insufficient VRAM for ~"
                 << (bytes_total / (1024.0 * 1024 * 1024))
                 << " GiB device alloc.";
  }
  ASSERT_EQ(herr, hipSuccess) << "hipMalloc failed: " << static_cast<int>(herr);

  // Every ASSERT_* below returns from the test body on failure. Own the
  // allocation through a guard so that those early exits still release the
  // device memory instead of leaking it for the rest of the test binary.
  struct HipFreeDeleter {
    void operator()(void* ptr) const { (void)hipFree(ptr); }
  };
  std::unique_ptr<void, HipFreeDeleter> dev_guard(dev_ptr);

  // Views:
  //   base: whole allocation
  //   lhs : [0 .. N-1]
  //   rhs : [shift .. shift + N-1]
  se::DeviceMemoryBase base(dev_ptr, bytes_total);
  se::DeviceMemoryBase lhs(dev_ptr,
                           static_cast<size_t>(element_count));  // bytes for U8
  se::DeviceMemoryBase rhs(static_cast<char*>(dev_ptr) + kShiftBytes,
                           lhs.size());

  // Zero the whole allocation so both views start out with identical contents.
  ASSERT_TRUE(stream->MemZero(&base, bytes_total).ok());
  ASSERT_TRUE(stream->BlockHostUntilDone().ok());

  // Comparator over a rank-1 U8 shape of N elements with zero tolerance.
  BufferComparator comparator(
      ShapeUtil::MakeShape(number_type, {static_cast<int64_t>(element_count)}),
      /*tolerance=*/0.0, /*verbose=*/false);

  // Pass 1: both views read zeros -> equal.
  {
    auto eq_or = comparator.CompareEqual(stream.get(), lhs, rhs);
    ASSERT_TRUE(eq_or.ok()) << eq_or.status();
    EXPECT_TRUE(eq_or.value());
  }

  // Flip the very last element of rhs via stream-sequenced memcpy
  // (host->device). That byte is the final byte of the allocation and lies
  // beyond the end of lhs, so only rhs changes.
  const NT new_val = static_cast<NT>(0xA5);
  se::DeviceMemoryBase rhs_last(
      static_cast<char*>(rhs.opaque()) + (element_count - 1), sizeof(NT));
  ASSERT_TRUE(stream->Memcpy(&rhs_last, &new_val, sizeof(NT)).ok());
  ASSERT_TRUE(stream->BlockHostUntilDone().ok());

  // Pass 2: must detect tail mismatch.
  {
    auto eq_or = comparator.CompareEqual(stream.get(), lhs, rhs);
    ASSERT_TRUE(eq_or.ok()) << eq_or.status();
    EXPECT_FALSE(eq_or.value());
  }

  // Happy path: take the pointer back from the guard so the result of the
  // free can still be checked explicitly.
  ASSERT_EQ(hipFree(dev_guard.release()), hipSuccess);
}

} // namespace
} // namespace gpu
} // namespace xla
157 changes: 149 additions & 8 deletions xla/backends/profiler/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,58 @@ cc_library(
],
)

# Matches builds invoked with --define=xla_rocm_profiler=v1.
# Used by the select() expressions in this package to pick the v1 tracer
# backend.
config_setting(
name = "use_v1",
values = {"define": "xla_rocm_profiler=v1"},
)

# Matches builds invoked with --define=xla_rocm_profiler=v3.
# Used by the select() expressions in this package to pick the
# rocprofiler-sdk tracer backend.
config_setting(
name = "use_rocprofiler_sdk",
values = {"define": "xla_rocm_profiler=v3"},
)

# Source-less library whose only job is to export the
# XLA_GPU_ROCM_TRACER_BACKEND preprocessor define to targets that depend on it.
cc_library(
name = "rocm_profiler_backend_cfg",
# 1 when :use_v1 matches; 3 when :use_rocprofiler_sdk matches and also when
# neither setting matches (the default branch).
defines = select({
":use_v1": ["XLA_GPU_ROCM_TRACER_BACKEND=1"],
":use_rocprofiler_sdk": ["XLA_GPU_ROCM_TRACER_BACKEND=3"],
"//conditions:default": ["XLA_GPU_ROCM_TRACER_BACKEND=3"],
}),
visibility = ["//visibility:public"],
)

# Library built from rocm_tracer_utils.{h,cc}; publicly visible and linked
# against the rocprofiler-sdk target from @local_config_rocm.
cc_library(
name = "rocm_tracer_utils",
srcs = ["rocm_tracer_utils.cc"],
hdrs = ["rocm_tracer_utils.h"],
# "manual" keeps this target out of wildcard patterns such as //...; it is
# built only when requested explicitly or pulled in as a dependency.
tags = [
"gpu",
"manual",
"rocm-only",
],
deps = [
"//xla/tsl/profiler/backends/cpu:annotation_stack",
"//xla/tsl/profiler/utils:time_utils",
"//xla/tsl/profiler/utils:math_utils",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/container:node_hash_set",
"@tsl//tsl/platform:env_time",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:macros",
"@local_config_rocm//rocm:rocprofiler-sdk",
],
visibility = ["//visibility:public"],
)

cc_library(
name = "rocm_collector",
srcs = ["rocm_collector.cc"],
hdrs = ["rocm_collector.h"],
# copybara:uncomment compatible_with = ["//buildenv/target:non_prod"],
tags = [
"gpu",
"rocm-only",
Expand All @@ -372,6 +419,8 @@ cc_library(
"manual",
]),
deps = [
":rocm_tracer_utils",
":rocm_profiler_backend_cfg",
"//xla/stream_executor/rocm:roctracer_wrapper",
"//xla/tsl/profiler/backends/cpu:annotation_stack",
"//xla/tsl/profiler/utils:parse_annotation",
Expand All @@ -396,26 +445,53 @@ cc_library(
"@tsl//tsl/platform:types",
"@tsl//tsl/profiler/lib:profiler_factory",
"@tsl//tsl/profiler/lib:profiler_interface",
"@local_config_rocm//rocm:rocprofiler-sdk",
],
)

cc_library(
name = "rocm_tracer",
srcs = ["rocm_tracer.cc"],
hdrs = ["rocm_tracer.h"],
# copybara:uncomment compatible_with = ["//buildenv/target:non_prod"],
name = "rocm_tracer_headers",
hdrs = [
"rocm_tracer.h",
"rocm_profiler_sdk.h",
"rocm_tracer_v1.h",
],
tags = [
"gpu",
"manual",
"rocm-only",
] + if_google([
# TODO(b/360374983): Remove this tag once the target can be built without --config=rocm.
],
# PROPAGATE the layout macro to every dependent TU:
defines = select({
":use_v1": ["XLA_GPU_ROCM_TRACER_BACKEND=1"],
":use_rocprofiler_sdk": ["XLA_GPU_ROCM_TRACER_BACKEND=3"],
"//conditions:default": ["XLA_GPU_ROCM_TRACER_BACKEND=3"],
}),
visibility = ["//visibility:public"],
)

cc_library(
name = "rocm_tracer_impl",
srcs = select({
":use_v1": ["rocm_tracer_v1.cc"],
":use_rocprofiler_sdk": ["rocm_profiler_sdk.cc"],
"//conditions:default": ["rocm_profiler_sdk.cc"],
}),
tags = [
"gpu",
"manual",
]),
"rocm-only",
],
deps = [
":rocm_tracer_headers",
":rocm_collector",
"//xla/stream_executor/rocm:roctracer_wrapper",
"//xla/tsl/profiler/backends/cpu:annotation_stack",
"//xla/tsl/profiler/utils:time_utils",
"//xla/tsl/profiler/utils:xplane_builder",
"//xla/tsl/profiler/utils:xplane_schema",
"//xla/tsl/profiler/utils:xplane_utils",
"//xla/tsl/util:env_var",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand All @@ -432,9 +508,74 @@ cc_library(
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:thread_annotations",
"@tsl//tsl/platform:types",
"@tsl//tsl/profiler/lib:profiler_factory",
"@tsl//tsl/profiler/lib:profiler_interface",
],
)

# Public aggregate target: it has no sources of its own and simply bundles
# :rocm_tracer_headers and :rocm_tracer_impl behind the single label
# :rocm_tracer.
cc_library(
name = "rocm_tracer",
tags = [
"gpu",
"manual",
"rocm-only",
],
deps = [":rocm_tracer_headers", ":rocm_tracer_impl"],
visibility = ["//visibility:public"],
)

# Test for :rocm_tracer, built from rocm_tracer_test.cc.
# NOTE: upstream declares this target with xla_cc_test because no GPU is
# involved; here it is declared with xla_test.
xla_test(
name = "rocm_tracer_test",
size = "small",
srcs = ["rocm_tracer_test.cc"],
tags = [
"gpu",
"rocm-only",
# Tag that lets runs exclude this test; the test reaches rocprofiler-sdk
# through rocprofiler_force_configure.
"skip_rocprofiler_sdk", # due to rocprofiler-sdk's rocprofiler_force_configure
] + if_google([
# Optional: only run internally if ROCm config is enabled
"manual",
]),
deps = [
":rocm_tracer",
":rocm_tracer_utils",
"//xla/tsl/profiler/utils:xplane_builder",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:test",
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
],
)

# Test for :rocm_collector, built from rocm_collector_test.cc.
xla_test(
name = "rocm_collector_test",
size = "small",
srcs = ["rocm_collector_test.cc"],
# The "manual" tag is appended only through if_google(...).
tags = [
"gpu",
"rocm-only",
] + if_google([
"manual",
]),
deps = [
":rocm_collector",
":rocm_tracer_utils",
"//xla/tsl/profiler/utils:xplane_builder",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:env_time",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:test",
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:macros",
],
)

cc_library(
name = "nvtx_utils",
srcs = ["nvtx_utils.cc"],
Expand Down
Loading