Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions xla/backends/gpu/codegen/triton/fusion_emitter_large_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,6 @@ class TritonGemmTest : public GpuCodegenTest {
.gpu_compute_capability();
}

// Per-test setup hook: marks the test as skipped when the device reports a
// ROCm compute capability; otherwise does nothing.
void SetUp() override {
  const bool on_rocm = std::holds_alternative<se::RocmComputeCapability>(
      GetGpuComputeCapability());
  if (!on_rocm) {
    return;
  }
  GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled.";
}

DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_cublas_fallback(false);
Expand All @@ -52,6 +45,10 @@ class TritonGemmTest : public GpuCodegenTest {
};

TEST_F(TritonGemmTest, IndexUsing64Bits) {
if (std::holds_alternative<se::RocmComputeCapability>(
GetGpuComputeCapability())) {
GTEST_SKIP() << "Not enough memory on ROCm.";
}
const char* kHloTextRef = R"(
HloModule r

Expand Down Expand Up @@ -135,10 +132,22 @@ ENTRY e {
EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

using TritonNormalizationTest = GpuCodegenTest;
// Test fixture built on GpuCodegenTest that adds an accessor for the compute
// capability of the device the tests run on.
class TritonNormalizationTest : public GpuCodegenTest {
 public:
  // Returns the GPU compute capability taken from the device description of
  // the backend's default stream executor.
  se::GpuComputeCapability GetGpuComputeCapability() {
    const auto& device_description =
        backend().default_stream_executor()->GetDeviceDescription();
    return device_description.gpu_compute_capability();
  }
};

TEST_F(TritonNormalizationTest,
CanEmitDiamondWithInputNumberOfElementsLargerThanInt32Max) {
if (std::holds_alternative<se::RocmComputeCapability>(
GetGpuComputeCapability())) {
GTEST_SKIP() << "Not enough memory on ROCm.";
}
const std::string hlo_text = R"(
HloModule softmax

Expand Down
2 changes: 2 additions & 0 deletions xla/backends/gpu/codegen/triton/support_legacy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ bool IsDotAlgorithmSupportedByTriton(
if (rocm_compute_capability) {
return rocm_compute_capability->has_bf16_dtype_support();
}
return false;
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9:
if (cuda_compute_capability) {
Expand All @@ -260,6 +261,7 @@ bool IsDotAlgorithmSupportedByTriton(
if (rocm_compute_capability) {
return true;
}
return false;
default:
return false;
}
Expand Down