Skip to content

Commit fbce546

Browse files
hsharsha and Google-ML-Automation
authored and committed
PR #36046: [ROCm] Fix failing unit tests on ROCm platform
Imported from GitHub PR #36046 📝 Summary of Changes - layout_assignment tests are marked cuda-only. - sample_file_test needs a higher autotuner level for MIOpen to return a conv algorithm. Earlier this was coming from GetDebugOptionsForTest. - buffer_debug_log test is made GPU-agnostic by using the canonical GPU name. - cublas_gemm_rewriter_test_amdgpu_any: fix unit test to remove padding for ROCm, as introduced in #33854. - gpu_kernel_tiling_test_amdgpu_any is updated to respect the higher launch dimensions now supported by the HIP runtime. - Mark dynamic_shared_memory_test as cuda-only. - Add arch-specific checks for barriers to sorting.hlo. 🎯 Justification Fixes failing unit tests on ROCm platform 🚀 Kind of Contribution 🐛 Bug Fix, 🧪 Tests Copybara import of the project: -- 472cd54 by Harsha HS <Harsha.HavanurShamsundara@amd.com>: [ROCm] Fix failing unit tests on ROCm platform - layout_assignment tests are marked cuda-only. - sample_file_test needs a higher autotuner level for MIOpen to return a conv algorithm. Earlier this was coming from GetDebugOptionsForTest. - buffer_debug_log test is made GPU-agnostic by using the canonical GPU name. -- 3bb9422 by Harsha HS <Harsha.HavanurShamsundara@amd.com>: Fix tests which started to fail due to #33854 -- 850d955 by Harsha HS <Harsha.HavanurShamsundara@amd.com>: HIP now respects higher launch dimensions, similar to CUDA -- b504a7e by Harsha HS <Harsha.HavanurShamsundara@amd.com>: Make dynamic_shared_memory_test cuda-only -- 1e4e57a by Harsha HS <Harsha.HavanurShamsundara@amd.com>: Add arch-specific checks to sorting.hlo -- ce1241c by Harsha HS <Harsha.HavanurShamsundara@amd.com>: Address review comments Merging this change closes #36046 FUTURE_COPYBARA_INTEGRATE_REVIEW=#36046 from ROCm:ci_fix_upstream_ut_20260107 ce1241c PiperOrigin-RevId: 855607651
1 parent a6da902 commit fbce546

9 files changed

Lines changed: 126 additions & 50 deletions

File tree

xla/service/gpu/tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,7 @@ xla_test(
836836
name = "dynamic_shared_memory_test",
837837
srcs = if_cuda_is_configured(["dynamic_shared_memory_test.cc"]),
838838
backends = ["gpu"],
839+
tags = ["cuda-only"],
839840
deps = [
840841
"//xla:shape_util",
841842
"//xla:types",

xla/service/gpu/tests/gpu_kernel_tiling_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ TEST_F(GpuKernelTilingTest, ReductionInputTooLarge) {
420420
if (xla::PlatformUtil::CanonicalPlatformName("gpu").value() == "rocm") {
421421
EXPECT_THAT(status.message(),
422422
::testing::ContainsRegex(
423-
"Kernel '.*' launch needs more blocks [(]2147483648, 1[)] "
423+
"Kernel '.*' launch needs more blocks [(]4294967296, 1[)] "
424424
"than allowed by hardware [(]2147483647, 65536[)]"));
425425
} else {
426426
EXPECT_THAT(status.message(),

xla/service/gpu/tests/sorting.hlo

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -609,22 +609,26 @@ compare {
609609
// CHECK: %[[VAL_405:.*]] = icmp slt i64 %[[VAL_404]], 3
610610
// CHECK: br i1 %[[VAL_405]], label %[[VAL_406:.*]], label %[[VAL_407:.*]]
611611
// CHECK: smaller_keys_index-after29: ; preds = %[[VAL_406]], %[[VAL_403]]
612-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
612+
// CHECK-PTX: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
613+
// CHECK-GCN: call void @llvm.amdgcn.s.barrier
613614
// CHECK: %[[VAL_408:.*]] = mul i64 %[[VAL_363]], 4
614615
// CHECK: %[[VAL_409:.*]] = icmp uge i64 %[[VAL_408]], 0
615616
// CHECK: br i1 %[[VAL_409]], label %[[VAL_410:.*]], label %[[VAL_411:.*]]
616617
// CHECK: is_last_tile-after: ; preds = %[[VAL_412:.*]], %[[VAL_413:.*]]
617-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
618+
// CHECK-PTX: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
619+
// CHECK-GCN: call void @llvm.amdgcn.s.barrier
618620
// CHECK: %[[VAL_414:.*]] = mul i64 %[[VAL_363]], 4
619621
// CHECK: %[[VAL_415:.*]] = icmp uge i64 %[[VAL_414]], 0
620622
// CHECK: br i1 %[[VAL_415]], label %[[VAL_416:.*]], label %[[VAL_417:.*]]
621623
// CHECK: is_last_tile-after56: ; preds = %[[VAL_418:.*]], %[[VAL_419:.*]]
622-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
624+
// CHECK-PTX: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
625+
// CHECK-GCN: call void @llvm.amdgcn.s.barrier
623626
// CHECK: %[[VAL_420:.*]] = mul i64 %[[VAL_363]], 4
624627
// CHECK: %[[VAL_421:.*]] = icmp uge i64 %[[VAL_420]], 0
625628
// CHECK: br i1 %[[VAL_421]], label %[[VAL_422:.*]], label %[[VAL_423:.*]]
626629
// CHECK: is_last_tile-after89: ; preds = %[[VAL_424:.*]], %[[VAL_425:.*]]
627-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
630+
// CHECK-PTX: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
631+
// CHECK-GCN: call void @llvm.amdgcn.s.barrier
628632
// CHECK: %[[VAL_426:.*]] = mul nuw nsw i64 %[[VAL_363]], 4
629633
// CHECK: %[[VAL_427:.*]] = mul nuw nsw i64 %[[VAL_371]], 4
630634
// CHECK: %[[VAL_428:.*]] = add nuw nsw i64 %[[VAL_426]], 0
@@ -1251,10 +1255,10 @@ compare {
12511255
// CHECK: %[[VAL_843:.*]] = load float, ptr %[[VAL_844:.*]], align 4
12521256
// CHECK: %[[VAL_845:.*]] = fcmp olt float %[[VAL_841]], %[[VAL_843]]
12531257
// CHECK: %[[VAL_846:.*]] = zext i1 %[[VAL_845]] to i8
1254-
// CHECK: store i8 %[[VAL_846]], ptr %[[VAL_840]], align 1
1255-
// CHECK: %[[VAL_847:.*]] = load i8, ptr %[[VAL_840]], align 1
1256-
// CHECK: store i8 %[[VAL_847]], ptr %[[VAL_848:.*]], align 1
1257-
// CHECK: ret void
1258+
// CHECK-PTX: store i8 %[[VAL_846]], ptr %[[VAL_840]], align 1
1259+
// CHECK-PTX: %[[VAL_847:.*]] = load i8, ptr %[[VAL_840]], align 1
1260+
// CHECK-PTX: store i8 %[[VAL_847]], ptr %[[VAL_848:.*]], align 1
1261+
// CHECK-PTX: ret void
12581262

12591263
ENTRY main {
12601264
x = s32[2, 3] parameter(0)
@@ -1286,7 +1290,7 @@ ENTRY main {
12861290
ROOT sort = (f64[2, 2048], f64[2, 2048], f64[2, 2048], f64[2, 2048]) sort(param0, param1, param2, param3), dimensions={1}, to_apply=compare
12871291
}
12881292
// Check that we have a tile size of 1024.
1289-
// CHECK: getelementptr [1024 x double], ptr addrspace(3) @sort_tile_param_0
1293+
// CHECK-PTX: getelementptr [1024 x double], ptr addrspace(3) @sort_tile_param_0
12901294

12911295
// -----
12921296

@@ -1304,4 +1308,4 @@ ENTRY main {
13041308
}
13051309

13061310
// CHECK-COUNT-334: xor i64
1307-
// CHECK-NOT: xor i64
1311+
// CHECK-GCN: xor i64

xla/service/gpu/transforms/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ load("//xla/tests:build_defs.bzl", "xla_test")
99
load("//xla/tsl:tsl.bzl", "if_oss")
1010
load(
1111
"//xla/tsl/platform:build_config_root.bzl",
12-
"tf_gpu_tests_tags",
12+
"tf_cuda_tests_tags",
1313
)
1414
load(
1515
"//xla/tsl/platform/default:cuda_build_defs.bzl",
@@ -1846,7 +1846,7 @@ lit_test_suite(
18461846
),
18471847
cfg = "//xla:lit.cfg.py",
18481848
data = ["//xla/backends/gpu/target_config:all_gpu_specs"],
1849-
default_tags = tf_gpu_tests_tags(),
1849+
default_tags = ["cuda-only"] + tf_cuda_tests_tags(),
18501850
tools = [
18511851
"//xla/tools:hlo-opt",
18521852
"@llvm-project//llvm:FileCheck",

xla/service/gpu/transforms/cublas_gemm_rewriter_test.cc

Lines changed: 96 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,11 +1639,19 @@ ENTRY test {
16391639
})";
16401640

16411641
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
1642-
MatchOptimizedHlo(hlo_text, R"(
1643-
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4], {{.*}}: bf16[4]) -> bf16[2,4] {
1644-
; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
1645-
; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
1642+
if (IsCuda()) {
1643+
MatchOptimizedHlo(hlo_text, R"(
1644+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4], {{.*}}: bf16[4]) -> bf16[2,4] {
1645+
; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
1646+
; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
16461647
)");
1648+
} else {
1649+
MatchOptimizedHlo(hlo_text, R"(
1650+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4], {{.*}}: bf16[4]) -> bf16[2,4] {
1651+
; CHECK-DAG: bf16[2,3]{1,0}
1652+
; CHECK-DAG: bf16[3,4]{1,0}
1653+
)");
1654+
}
16471655
}
16481656

16491657
TEST_F(CublasLtGemmRewriteTest, ReluActivation) {
@@ -2428,11 +2436,19 @@ ENTRY test {
24282436
})";
24292437

24302438
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{5e-5, 1e-5}));
2431-
MatchOptimizedHlo(hlo_text, R"(
2432-
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4]) -> bf16[2,4] {
2433-
; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
2434-
; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
2439+
if (IsCuda()) {
2440+
MatchOptimizedHlo(hlo_text, R"(
2441+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4]) -> bf16[2,4] {
2442+
; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
2443+
; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
24352444
)");
2445+
} else {
2446+
MatchOptimizedHlo(hlo_text, R"(
2447+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4]) -> bf16[2,4] {
2448+
; CHECK-DAG: bf16[2,3]{1,0}
2449+
; CHECK-DAG: bf16[3,4]{1,0}
2450+
)");
2451+
}
24362452
}
24372453

24382454
TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBitcast) {
@@ -2606,13 +2622,19 @@ ENTRY test {
26062622
})";
26072623

26082624
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
2609-
MatchOptimizedHlo(hlo_text,
2610-
R"(
2611-
2612-
; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
2613-
; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
2614-
; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
2625+
if (IsCuda()) {
2626+
MatchOptimizedHlo(hlo_text, R"(
2627+
; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
2628+
; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
2629+
; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
26152630
)");
2631+
} else {
2632+
MatchOptimizedHlo(hlo_text, R"(
2633+
; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
2634+
; CHECK-DAG: f16[6,12]{1,0}
2635+
; CHECK-DAG: f16[12,6]{1,0}
2636+
)");
2637+
}
26162638
}
26172639

26182640
// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
@@ -2657,11 +2679,19 @@ ENTRY test {
26572679
})";
26582680

26592681
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
2660-
MatchOptimizedHlo(hlo_text, R"(
2661-
; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6]) -> f16[6,6] {
2662-
; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
2663-
; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
2682+
if (IsCuda()) {
2683+
MatchOptimizedHlo(hlo_text, R"(
2684+
; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6]) -> f16[6,6] {
2685+
; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
2686+
; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
26642687
)");
2688+
} else {
2689+
MatchOptimizedHlo(hlo_text, R"(
2690+
; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6]) -> f16[6,6] {
2691+
; CHECK-DAG: f16[6,12]{1,0}
2692+
; CHECK-DAG: f16[12,6]{1,0}
2693+
)");
2694+
}
26652695
}
26662696

26672697
TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivationF16) {
@@ -2757,11 +2787,19 @@ ENTRY test {
27572787
})";
27582788

27592789
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
2760-
MatchOptimizedHlo(hlo_text, R"(
2761-
; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
2762-
; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
2763-
; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
2790+
if (IsCuda()) {
2791+
MatchOptimizedHlo(hlo_text, R"(
2792+
; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
2793+
; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
2794+
; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
27642795
)");
2796+
} else {
2797+
MatchOptimizedHlo(hlo_text, R"(
2798+
; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
2799+
; CHECK-DAG: f16[6,12]{1,0}
2800+
; CHECK-DAG: f16[12,6]{1,0}
2801+
)");
2802+
}
27652803
}
27662804

27672805
// For bfloat16, the sizes of all dimensions of the operands are required to be
@@ -2899,11 +2937,19 @@ ENTRY test {
28992937
})";
29002938

29012939
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
2902-
MatchOptimizedHlo(hlo_text, R"(
2903-
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
2904-
; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
2905-
; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
2940+
if (IsCuda()) {
2941+
MatchOptimizedHlo(hlo_text, R"(
2942+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
2943+
; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
2944+
; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
29062945
)");
2946+
} else {
2947+
MatchOptimizedHlo(hlo_text, R"(
2948+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
2949+
; CHECK-DAG: bf16[6,12]{1,0}
2950+
; CHECK-DAG: bf16[12,6]{1,0}
2951+
)");
2952+
}
29072953
}
29082954

29092955
// For bfloat16, the operands are padded if necessary on Ampere and newer
@@ -2955,11 +3001,19 @@ ENTRY test {
29553001
})";
29563002

29573003
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
2958-
MatchOptimizedHlo(hlo_text, R"(
2959-
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6]) -> bf16[6,6] {
2960-
; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
2961-
; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
3004+
if (IsCuda()) {
3005+
MatchOptimizedHlo(hlo_text, R"(
3006+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6]) -> bf16[6,6] {
3007+
; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
3008+
; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
29623009
)");
3010+
} else {
3011+
MatchOptimizedHlo(hlo_text, R"(
3012+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6]) -> bf16[6,6] {
3013+
; CHECK-DAG: bf16[6,12]{1,0}
3014+
; CHECK-DAG: bf16[12,6]{1,0}
3015+
)");
3016+
}
29633017
}
29643018

29653019
// For bfloat16, the operands are padded if necessary on Ampere and newer
@@ -3018,11 +3072,19 @@ ENTRY test {
30183072
30193073
)";
30203074
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
3021-
MatchOptimizedHlo(hlo_text, R"(
3022-
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
3023-
; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
3024-
; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
3075+
if (IsCuda()) {
3076+
MatchOptimizedHlo(hlo_text, R"(
3077+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
3078+
; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
3079+
; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
30253080
)");
3081+
} else {
3082+
MatchOptimizedHlo(hlo_text, R"(
3083+
; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
3084+
; CHECK-DAG: bf16[6,12]{1,0}
3085+
; CHECK-DAG: bf16[12,6]{1,0}
3086+
)");
3087+
}
30263088
}
30273089

30283090
TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF64) {

xla/stream_executor/gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,7 @@ xla_test(
917917
"//xla/backends/gpu/runtime:buffer_debug_log_proto_cc",
918918
"//xla/backends/gpu/runtime:buffer_debug_log_structs",
919919
"//xla/backends/gpu/runtime:thunk_id",
920+
"//xla/service:platform_util",
920921
"//xla/stream_executor:device_address",
921922
"//xla/stream_executor:platform",
922923
"//xla/stream_executor:platform_manager",

xla/stream_executor/gpu/buffer_debug_log_test.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ limitations under the License.
2626
#include <gtest/gtest.h>
2727
#include "absl/status/status.h"
2828
#include "absl/status/status_matchers.h"
29+
#include "absl/strings/ascii.h"
2930
#include "xla/backends/gpu/runtime/buffer_debug_log.pb.h"
3031
#include "xla/backends/gpu/runtime/buffer_debug_log_structs.h"
3132
#include "xla/backends/gpu/runtime/thunk_id.h"
33+
#include "xla/service/platform_util.h"
3234
#include "xla/stream_executor/device_address.h"
3335
#include "xla/stream_executor/platform.h"
3436
#include "xla/stream_executor/platform_manager.h"
@@ -47,8 +49,9 @@ using ::xla::gpu::ThunkId;
4749
class BufferDebugLogTest : public ::testing::Test {
4850
protected:
4951
void SetUp() override {
50-
TF_ASSERT_OK_AND_ASSIGN(platform_,
51-
PlatformManager::PlatformWithName("CUDA"));
52+
auto name = absl::AsciiStrToUpper(
53+
xla::PlatformUtil::CanonicalPlatformName("gpu").value());
54+
TF_ASSERT_OK_AND_ASSIGN(platform_, PlatformManager::PlatformWithName(name));
5255
TF_ASSERT_OK_AND_ASSIGN(executor_, platform_->ExecutorForDevice(0));
5356
TF_ASSERT_OK_AND_ASSIGN(stream_, executor_->CreateStream(std::nullopt));
5457
allocator_ =

xla/tests/sample_file_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ TEST_F(SampleFileTest, Convolution) {
6363
.mutable_debug_options()
6464
.set_xla_cpu_parallel_codegen_split_count(1);
6565

66+
module->mutable_config().mutable_debug_options().set_xla_gpu_autotune_level(
67+
4);
6668
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01}));
6769
}
6870

xla/tsl/platform/default/build_config_root.bzl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ def tf_gpu_tests_tags():
4444

4545
# terminology changes: saving tf_cuda_* for compatibility
4646
def tf_cuda_tests_tags():
47-
return tf_gpu_tests_tags()
47+
if is_cuda_configured():
48+
return ["requires-gpu-cuda", "gpu"] + gpu_test_tags()
49+
else:
50+
return []
4851

4952
def tf_has_tag(kwargs, tag):
5053
return ("tags" in kwargs and kwargs["tags"] != None and tag in kwargs["tags"])

0 commit comments

Comments
 (0)