Skip to content

Commit 4fd5ed4

Browse files
authored
[ROCM] added a cudnn switch of conv2d for rocm platform (#31836) (#31932)
1 parent 9b40cb8 commit 4fd5ed4

File tree

6 files changed

+61
-1
lines changed

6 files changed

+61
-1
lines changed

paddle/fluid/platform/flags.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,3 +564,15 @@ DEFINE_string(tracer_mkldnn_ops_on, "",
564564
*/
565565
DEFINE_string(tracer_mkldnn_ops_off, "",
566566
"List of OneDNN operation types to be turned off");
567+
568+
/**
569+
* CUDNN related FLAG
570+
* Name: conv2d_disable_cudnn
571+
* Since Version:
572+
* Value Range: bool, default=false
573+
* Example:
574+
 * Note: When true, the Python-side conv2d code sets use_cudnn to False (checked only when core.is_compiled_with_cuda() is true).
575+
*/
576+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
577+
DEFINE_bool(conv2d_disable_cudnn, false, "Disable cudnn in conv2d");
578+
#endif

paddle/fluid/pybind/global_value_getter_setter.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ DECLARE_uint64(conv_workspace_size_limit);
7272
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
7373
DECLARE_bool(cudnn_deterministic);
7474
DECLARE_bool(cudnn_exhaustive_search);
75+
DECLARE_bool(conv2d_disable_cudnn);
7576
// data processing
7677
DECLARE_bool(enable_cublas_tensor_op_math);
7778
// device management
@@ -367,7 +368,8 @@ static void RegisterGlobalVarGetterSetter() {
367368
FLAGS_fraction_of_cuda_pinned_memory_to_use,
368369
FLAGS_fraction_of_gpu_memory_to_use, FLAGS_initial_gpu_memory_in_mb,
369370
FLAGS_reallocate_gpu_memory_in_mb, FLAGS_enable_cublas_tensor_op_math,
370-
FLAGS_selected_gpus, FLAGS_sync_nccl_allreduce);
371+
FLAGS_selected_gpus, FLAGS_sync_nccl_allreduce,
372+
FLAGS_conv2d_disable_cudnn);
371373
#endif
372374
#ifdef PADDLE_WITH_XPU
373375
REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_selected_xpus);

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def __bootstrap__():
230230
'gpu_allocator_retry_time',
231231
'local_exe_sub_scope_limit',
232232
'gpu_memory_limit_mb',
233+
'conv2d_disable_cudnn',
233234
]
234235
core.init_gflags(["--tryfromenv=" + ",".join(read_env_flags)])
235236
core.init_glog(sys.argv[0])

python/paddle/fluid/layers/nn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,10 @@ def _get_default_param_initializer():
16031603

16041604
pre_bias = helper.create_variable_for_type_inference(dtype)
16051605

1606+
if (core.is_compiled_with_cuda() and paddle.fluid.get_flags(
1607+
"FLAGS_conv2d_disable_cudnn")["FLAGS_conv2d_disable_cudnn"]):
1608+
use_cudnn = False
1609+
16061610
helper.append_op(
16071611
type=l_type,
16081612
inputs={

python/paddle/fluid/tests/unittests/test_conv2d_op.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,5 +1465,41 @@ def run_7():
14651465
self.assertRaises(ValueError, run_7)
14661466

14671467

1468+
# --------- test environment variable ------
1469+
@unittest.skipIf(
    not (core.is_compiled_with_cuda() or core.is_compiled_with_rocm()),
    "core is not compiled with CUDA or ROCM")
class TestConv2DEnviron(unittest.TestCase):
    """Smoke-test the conv2d APIs with FLAGS_conv2d_disable_cudnn off and on.

    The test only checks that building/running conv2d does not raise under
    either flag value; it makes no assertion about the computed output.
    """

    def setUp(self):
        # The flag is process-global. Record its current value and restore it
        # when the test finishes (pass or fail) so that toggling it here does
        # not change the behaviour of tests that run afterwards.
        flag_name = 'FLAGS_conv2d_disable_cudnn'
        saved_value = fluid.get_flags(flag_name)[flag_name]
        self.addCleanup(fluid.set_flags, {flag_name: saved_value})

    def run_conv2d_api(self):
        """Exercise both fluid.layers.conv2d and paddle.nn.Conv2D once."""
        inputs = fluid.layers.data(
            shape=[2, 3, 5, 5],
            append_batch_size=False,
            name="inputs",
            dtype="float32")
        fluid.layers.conv2d(
            input=inputs,
            num_filters=4,
            filter_size=[3, 3],
            stride=[1, 1],
            padding=0,
            dilation=[1, 1],
            groups=1,
            data_format="NCHW")

        x_var = paddle.uniform((2, 3, 5, 5), dtype="float32", min=-1., max=1.)
        conv = paddle.nn.Conv2D(
            in_channels=3,
            out_channels=4,
            kernel_size=(3, 3),
            data_format="NCHW")
        # Result is intentionally discarded: only successful execution matters.
        conv(x_var)

    def test_environ(self):
        """Run the conv2d APIs with the cudnn-disable flag False, then True."""
        fluid.set_flags({'FLAGS_conv2d_disable_cudnn': False})
        self.run_conv2d_api()
        fluid.set_flags({'FLAGS_conv2d_disable_cudnn': True})
        self.run_conv2d_api()
1502+
1503+
14681504
if __name__ == '__main__':
14691505
unittest.main()

python/paddle/nn/layer/conv.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import numpy as np
2727

28+
from ...fluid import get_flags
2829
from ...fluid import core
2930
from ...device import get_cudnn_version
3031
from ...fluid.dygraph import layers
@@ -644,6 +645,10 @@ def __init__(self,
644645
bias_attr=bias_attr,
645646
data_format=data_format)
646647

648+
if (core.is_compiled_with_cuda() and get_flags(
649+
"FLAGS_conv2d_disable_cudnn")["FLAGS_conv2d_disable_cudnn"]):
650+
self._use_cudnn = False
651+
647652
def forward(self, x):
648653
if self._padding_mode != 'zeros':
649654
x = F.pad(x,

0 commit comments

Comments
 (0)