Skip to content

Commit edab8b2

Browse files
[ROCm] Disable Cudnn fusions (#358)
1 parent fe04251 commit edab8b2

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

xla/service/gpu/amdgpu_compiler.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
122122
stream_executor::RocmSolverContext::Create);
123123
pipeline.AddPass<ConvRewriter>(gpu_version);
124124
pipeline.AddPass<ConvPaddingLegalization>();
125-
auto rcc = std::get<se::RocmComputeCapability>(gpu_version);
126-
pipeline.AddPass<CudnnFusedConvRewriter>(rcc, dnn_version, toolkit_version);
125+
//TODO(rocm): Until #12613 is fixed.
126+
// auto rcc = std::get<se::RocmComputeCapability>(gpu_version);
127+
// pipeline.AddPass<CudnnFusedConvRewriter>(rcc, dnn_version, toolkit_version);
127128

128129
// The conv padding/vectorization passes which we need to get rid of. They
129130
// also leave behind unnecessary tuple/get-tuple-element pairs that

xla/service/gpu/transforms/BUILD

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,9 @@ cc_library(
10091009
xla_test(
10101010
name = "cudnn_fused_conv_rewriter_test",
10111011
srcs = ["cudnn_fused_conv_rewriter_test.cc"],
1012+
tags = [
1013+
"cuda-only", # TODO(rocm): Until #12613 is fixed.
1014+
],
10121015
backend_tags = {
10131016
"gpu_a100": [
10141017
"noasan",
@@ -1017,8 +1020,7 @@ xla_test(
10171020
},
10181021
backends = [
10191022
"gpu_a100",
1020-
"gpu_amd_any",
1021-
] + if_oss(["gpu_any"]),
1023+
],
10221024
shard_count = 10,
10231025
deps = [
10241026
":conv_rewriter",

0 commit comments

Comments
 (0)