Skip to content

Commit 1240e4f

Browse files
author
lilong12
committed
align the default values of some fleet configuration options with those used for single-card training (#30740)
* update, test=develop
1 parent cb5f043 commit 1240e4f

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

paddle/fluid/framework/distributed_strategy.proto

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,9 @@ message DistributedStrategy {
139139
optional bool fuse_all_reduce_ops = 18 [ default = true ];
140140
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
141141
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
142-
optional bool cudnn_exhaustive_search = 21 [ default = true ];
142+
optional bool cudnn_exhaustive_search = 21 [ default = false ];
143143
optional int32 conv_workspace_size_limit = 22 [ default = 512 ];
144-
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
144+
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = false ];
145145
optional bool adaptive_localsgd = 24 [ default = false ];
146146
optional bool fp16_allreduce = 25 [ default = false ];
147147
optional bool sharding = 26 [ default = false ];

python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,22 @@ def __init__(self):
115115
116116
"""
117117
self.strategy = distributed_strategy_pb2.DistributedStrategy()
118+
119+
# Set the default values of the following flags to the ones set by users
120+
key = 'FLAGS_cudnn_batchnorm_spatial_persistent'
121+
if core.globals().is_public(key):
122+
self.strategy.cudnn_batchnorm_spatial_persistent = bool(
123+
core.globals()[key])
124+
key = 'FLAGS_conv_workspace_size_limit'
125+
if core.globals().is_public(key):
126+
self.strategy.conv_workspace_size_limit = int(core.globals()[key])
127+
key = 'FLAGS_cudnn_exhaustive_search'
128+
if core.globals().is_public(key):
129+
self.strategy.cudnn_exhaustive_search = bool(core.globals()[key])
130+
key = 'FLAGS_sync_nccl_allreduce'
131+
if core.globals().is_public(key):
132+
self.strategy.sync_nccl_allreduce = bool(core.globals()[key])
133+
118134
self.__lock_attr = True
119135

120136
def __setattr__(self, key, value):

0 commit comments

Comments
 (0)