Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 76 additions & 29 deletions xla/backends/gpu/runtime/buffer_comparator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ limitations under the License.
#include "tsl/platform/ml_dtypes.h"
#include "tsl/platform/test.h"

#include <hip/hip_runtime_api.h>

namespace xla {
namespace gpu {
namespace {
Expand Down Expand Up @@ -393,37 +395,82 @@ TEST_F(BufferComparatorTest, BF16) {
.value());
}

TEST_F(BufferComparatorTest, VeryLargeArray) {
constexpr PrimitiveType number_type = U16;
using NT = primitive_util::PrimitiveTypeToNative< number_type >::type;
// ROCm-only: CI-safe very-large compare using *device* memory and U8 type.
// - No HMM / mapped host memory
// - 64-bit indexing stress (N > 2^31)
// - Alignment-friendly aliasing offset
TEST_F(BufferComparatorTest, VeryLargeArray_Device_U8_Aligned) {
// Force 64-bit indexing without exceeding ~4 GiB VRAM.
constexpr PrimitiveType number_type = U8;
using NT = primitive_util::PrimitiveTypeToNative<number_type>::type;
static_assert(sizeof(NT) == 1, "This test expects 8-bit elements.");

// N just above 2^31 so 32-bit indexing overflows.
const uint64_t element_count = (1ull << 31) + 4096; // 2,147,488,744
// Shift rhs by a cacheline (64 B) to keep accesses aligned even if
// vectorized.
constexpr size_t kShiftBytes = 64;

// Total device allocation: N + shift (≈ 2.15 GiB)
const size_t bytes_total = static_cast<size_t>(element_count) + kShiftBytes;

// Create stream.
auto stream_or = stream_exec_->CreateStream();
ASSERT_TRUE(stream_or.ok()) << stream_or.status();
std::unique_ptr<se::Stream> stream = std::move(stream_or.value());

// Single device allocation of (N + shift) bytes.
void* dev_ptr = nullptr;
hipError_t herr = hipMalloc(&dev_ptr, bytes_total);
if (herr == hipErrorMemoryAllocation) {
GTEST_SKIP() << "Insufficient VRAM for ~"
<< (bytes_total / (1024.0 * 1024 * 1024))
<< " GiB device alloc.";
}
ASSERT_EQ(herr, hipSuccess) << "hipMalloc failed: " << static_cast<int>(herr);

// Views:
// base: whole allocation
// lhs : [0 .. N-1]
// rhs : [shift .. shift + N-1]
se::DeviceMemoryBase base(dev_ptr, bytes_total);
se::DeviceMemoryBase lhs(dev_ptr,
static_cast<size_t>(element_count)); // bytes for U8
se::DeviceMemoryBase rhs(static_cast<char*>(dev_ptr) + kShiftBytes,
lhs.size());

// Initialize with zeros (MemZero works for any size/alignment).
ASSERT_TRUE(stream->MemZero(&base, bytes_total).ok());
ASSERT_TRUE(stream->BlockHostUntilDone().ok());

// Comparator: force device path (exercise kernel indexing).
BufferComparator comparator(
ShapeUtil::MakeShape(number_type, {static_cast<int64_t>(element_count)}),
/*tolerance=*/0.0, /*verbose=*/false, /*run_host_compare=*/false);

// Pass 1: both views read zeros -> equal.
{
auto eq_or = comparator.CompareEqual(stream.get(), lhs, rhs);
ASSERT_TRUE(eq_or.ok()) << eq_or.status();
EXPECT_TRUE(eq_or.value());
}

// Flip the very last element of rhs via stream-sequenced memcpy
// (host->device).
const NT new_val = static_cast<NT>(0xA5);
se::DeviceMemoryBase rhs_last(
static_cast<char*>(rhs.opaque()) + (element_count - 1), sizeof(NT));
ASSERT_TRUE(stream->Memcpy(&rhs_last, &new_val, sizeof(NT)).ok());
ASSERT_TRUE(stream->BlockHostUntilDone().ok());

// Set non-power-of-two element count on purpose
int64_t element_count = (1LL << 34) - 11;
auto stream = stream_exec_->CreateStream().value();
// Pass 2: must detect tail mismatch.
{
auto eq_or = comparator.CompareEqual(stream.get(), lhs, rhs);
ASSERT_TRUE(eq_or.ok()) << eq_or.status();
EXPECT_FALSE(eq_or.value());
}

// Use host memory here since there is a limitation of 4GB per test on
// device memory alloc
TF_ASSERT_OK_AND_ASSIGN(
auto base,
stream_exec_->HostMemoryAllocate((element_count + 1) * sizeof(NT)));

// We use overlapping lhs and rhs arrays to reduce memory usage, also this
// serves as an extra test for possible pointer aliasing problems
se::DeviceMemoryBase lhs(base->opaque(), base->size() - sizeof(NT)),
rhs(static_cast< NT *>(base->opaque()) + 1, lhs.size());

TF_CHECK_OK(stream->Memset32(&lhs, 0xABCDABCD, base->size()));

// Disable host comparison here since it could take a while for ~8GB array
BufferComparator comparator(ShapeUtil::MakeShape(number_type, {element_count}),
/*tolerance*/0.1, /* verbose */false, /*run_host_compare*/false);
EXPECT_TRUE(comparator.CompareEqual(stream.get(), lhs, rhs).value());

// Change only the very last entry of rhs to verify that the whole arrays are
// compared (if the grid dimensions are not computed correctly, this might
// not be the case)
*(static_cast< NT *>(rhs.opaque()) + element_count - 1) = 1777;
EXPECT_FALSE(comparator.CompareEqual(stream.get(), lhs, rhs).value());
ASSERT_EQ(hipFree(dev_ptr), hipSuccess);
}

} // namespace
Expand Down
182 changes: 167 additions & 15 deletions xla/backends/profiler/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -217,20 +217,80 @@ tsl_gpu_library(
],
)

tsl_gpu_library(
config_setting(
name = "use_v1",
values = {"define": "xla_rocm_profiler=v1"},
)

config_setting(
name = "use_rocprofiler_sdk",
values = {"define": "xla_rocm_profiler=v3"},
)

cc_library(
name = "rocm_profiler_backend_cfg",
defines = select({
":use_v1": ["XLA_GPU_ROCM_TRACER_BACKEND=1"],
":use_rocprofiler_sdk": ["XLA_GPU_ROCM_TRACER_BACKEND=3"],
"//conditions:default": ["XLA_GPU_ROCM_TRACER_BACKEND=3"],
}),
visibility = ["//visibility:public"],
)

cc_library(
name = "rocm_tracer_utils",
srcs = ["rocm_tracer_utils.cc"],
hdrs = ["rocm_tracer_utils.h"],
tags = [
"gpu",
"manual",
"rocm-only",
],
linkopts = [
"-Wl,--no-as-needed",
"-L/opt/rocm/lib",
"-Wl,-rpath,/opt/rocm/lib",
"-lhsa-runtime64",
],
deps = [
"//xla/tsl/profiler/backends/cpu:annotation_stack",
"//xla/tsl/profiler/utils:time_utils",
"//xla/tsl/profiler/utils:math_utils",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/container:node_hash_set",
"@tsl//tsl/platform:env_time",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:macros",
],
visibility = ["//visibility:public"],
)

cc_library(
name = "rocm_collector",
srcs = if_rocm(["rocm_collector.cc"]),
hdrs = if_rocm(["rocm_collector.h"]),
copts = tf_profiler_copts() + tsl_copts(),
srcs = ["rocm_collector.cc"],
hdrs = ["rocm_collector.h"],
# copybara:uncomment compatible_with = ["//buildenv/target:non_prod"],
linkopts = select({
"//conditions:default": [
"-L/opt/rocm/lib", # search path for all ROCm shared objects
"-lrocprofiler-sdk", # the library that owns the missing symbols
],
}),
tags = [
"gpu",
"rocm-only",
] + if_google([
# TODO(b/360374983): Remove this tag once the target can be built without --config=rocm.
"manual",
]),
visibility = ["//visibility:public"],
deps = [
":rocm_tracer_utils",
":rocm_profiler_backend_cfg",
"//xla/stream_executor/rocm:roctracer_wrapper",
"//xla/tsl/profiler/backends/cpu:annotation_stack",
"//xla/tsl/profiler/utils:parse_annotation",
Expand All @@ -245,11 +305,11 @@ tsl_gpu_library(
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@tsl//tsl/platform:abi",
"@tsl//tsl/platform:env_time",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:macros",
"@tsl//tsl/platform:mutex",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:thread_annotations",
"@tsl//tsl/platform:types",
Expand All @@ -258,38 +318,130 @@ tsl_gpu_library(
],
)

tsl_gpu_library(
name = "rocm_tracer",
srcs = if_rocm(["rocm_tracer.cc"]),
hdrs = if_rocm(["rocm_tracer.h"]),
copts = tf_profiler_copts() + tsl_copts(),
cc_library(
name = "rocm_tracer_headers",
hdrs = [
"rocm_tracer.h",
"rocm_profiler_sdk.h",
"rocm_tracer_v1.h",
],
tags = [
"gpu",
"rocm-only",
] + if_google([
# TODO(b/360374983): Remove this tag once the target can be built without --config=rocm.
"manual",
]),
"rocm-only",
],
# PROPAGATE the layout macro to every dependent TU:
defines = select({
":use_v1": ["XLA_GPU_ROCM_TRACER_BACKEND=1"],
":use_rocprofiler_sdk": ["XLA_GPU_ROCM_TRACER_BACKEND=3"],
"//conditions:default": ["XLA_GPU_ROCM_TRACER_BACKEND=3"],
}),
visibility = ["//visibility:public"],
)

cc_library(
name = "rocm_tracer_impl",
srcs = select({
":use_v1": ["rocm_tracer_v1.cc"],
":use_rocprofiler_sdk": ["rocm_profiler_sdk.cc"],
"//conditions:default": ["rocm_profiler_sdk.cc"],
}),
tags = [
"gpu",
"manual",
"rocm-only",
],
deps = [
":rocm_tracer_headers",
":rocm_collector",
"//xla/stream_executor/rocm:roctracer_wrapper",
"//xla/tsl/profiler/backends/cpu:annotation_stack",
"//xla/tsl/profiler/utils:time_utils",
"//xla/tsl/profiler/utils:xplane_builder",
"//xla/tsl/profiler/utils:xplane_schema",
"//xla/tsl/profiler/utils:xplane_utils",
"//xla/tsl/util:env_var",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:macros",
"@tsl//tsl/platform:platform_port",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:thread_annotations",
"@tsl//tsl/platform:types",
"@tsl//tsl/profiler/lib:profiler_factory",
"@tsl//tsl/profiler/lib:profiler_interface",
],
)

cc_library(
name = "rocm_tracer",
tags = [
"gpu",
"manual",
"rocm-only",
],
deps = [":rocm_tracer_headers", ":rocm_tracer_impl"],
visibility = ["//visibility:public"],
)

# upstream it's called xla_cc_test as no GPU involved.
xla_test(
name = "rocm_tracer_test",
size = "small",
srcs = ["rocm_tracer_test.cc"],
tags = [
"gpu",
"rocm-only",
] + if_google([
# Optional: only run internally if ROCm config is enabled
"manual",
]),
deps = [
":rocm_tracer",
":rocm_tracer_utils",
"//xla/tsl/profiler/utils:xplane_builder",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:test",
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
],
)

xla_test(
name = "rocm_collector_test",
size = "small",
srcs = ["rocm_collector_test.cc"],
tags = [
"gpu",
"rocm-only",
] + if_google([
"manual",
]),
deps = [
":rocm_collector",
":rocm_tracer_utils",
"//xla/tsl/profiler/utils:xplane_builder",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:env_time",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:test",
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:macros",
],
)

tsl_gpu_library(
Expand Down
Loading