11 changes: 11 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -118,6 +118,16 @@ message ExecutionStrategy {
  optional bool use_thread_barrier = 4 [ default = false ];
}

message GradientScaleConfig {
  // Optional values: ['avg', 'sum', 'customized']
  // If avg, loss@grad will be divided by the number of devices,
  // that is, the gradient will be accumulated and averaged across
  // multiple devices.
  // Else if sum, the gradient will only be accumulated across
  // multiple devices.
  // Else if customized, the framework will not scale the gradient;
  // loss@grad is left for the user to customize.
  optional string scale_strategy = 1 [ default = 'avg' ];
}

message AsyncConfig {
  optional int32 k_steps = 1 [ default = -1 ];
  optional int32 max_merge_var_num = 2 [ default = 1 ];
@@ -194,6 +204,7 @@ message DistributedStrategy {
  optional TensorParallelConfig tensor_parallel_configs = 113;
  optional BuildStrategy build_strategy = 201;
  optional ExecutionStrategy execution_strategy = 202;
  optional GradientScaleConfig gradient_scale_configs = 203;
}

message DistributedJobInfo {
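For intuition, here is a minimal numpy sketch (not part of this patch; the two-device setup and values are made up) of what the 'avg' and 'sum' strategies mean for one parameter's gradient:

import numpy as np

num_devices = 2
g0 = np.array([0.2, 0.4])  # gradient computed on device 0
g1 = np.array([0.6, 0.8])  # gradient computed on device 1

# 'sum': the gradient is only accumulated across devices.
g_sum = g0 + g1                  # [0.8, 1.2]

# 'avg': the accumulated gradient is divided by the device count,
# i.e. loss@grad is scaled by 1 / num_devices.
g_avg = (g0 + g1) / num_devices  # [0.4, 0.6]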
22 changes: 22 additions & 0 deletions python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -254,6 +254,28 @@ def build_strategy(self, strategy):
                getattr(self.strategy.build_strategy,
                        f.name).extend(getattr(strategy, f.name))

    @property
    def gradient_scale_configs(self):
        """
        Set the strategy of gradient scaling.

        Note that 'scale_strategy' must be one of 'avg', 'sum' or 'customized'.

        Examples:

          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.gradient_scale_configs = {'scale_strategy': 'avg'}
        """
        return get_msg_dict(self.strategy.gradient_scale_configs)

    @gradient_scale_configs.setter
    @is_strict_auto
    def gradient_scale_configs(self, config):
        check_configs_key(self.strategy.gradient_scale_configs, config,
                          'gradient_scale_configs')
        assign_configs_value(self.strategy.gradient_scale_configs, config)

    @property
    def a_sync(self):
        """
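The getter/setter pair above gives a dict-style round trip over the GradientScaleConfig message. A short usage sketch (the printed form is indicative):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()

# The setter validates the dict keys against GradientScaleConfig
# (check_configs_key), so an unknown key raises an error.
strategy.gradient_scale_configs = {'scale_strategy': 'sum'}

# The getter converts the proto message back into a plain dict.
print(strategy.gradient_scale_configs)  # e.g. {'scale_strategy': 'sum'}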
12 changes: 12 additions & 0 deletions python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py
@@ -18,6 +18,7 @@
from .meta_optimizer_base import MetaOptimizerBase
from ..base.private_helper_function import wait_server_ready
import logging
from paddle.static import BuildStrategy

__all__ = []

@@ -147,6 +148,17 @@ def _try_to_compile(self, startup_program, main_program, loss):
        local_build_strategy.nccl_comm_num = \
            dist_strategy.nccl_comm_num

        gradient_scale_configs = self.user_defined_strategy.gradient_scale_configs
        scale_strategies = {
            'avg': BuildStrategy.GradientScaleStrategy.CoeffNumDevice,
            'sum': BuildStrategy.GradientScaleStrategy.One,
            'customized': BuildStrategy.GradientScaleStrategy.Customized,
        }
        assert gradient_scale_configs['scale_strategy'] in scale_strategies, \
            "gradient_scale_configs.scale_strategy must be 'avg', 'sum' or 'customized'"
        local_build_strategy.gradient_scale_strategy = \
            scale_strategies[gradient_scale_configs['scale_strategy']]

        if self.user_defined_strategy.recompute:
            logging.warn(
                "set enable_sequential_execution=True since you have enabled the recompute strategy"
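For reference, a hedged sketch of what the mapping above resolves to on a plain BuildStrategy outside of fleet, using only the enum values named in this diff:

from paddle.static import BuildStrategy

build_strategy = BuildStrategy()

# 'sum' in fleet terms: no coefficient is applied, the gradient is
# only accumulated across devices.
build_strategy.gradient_scale_strategy = \
    BuildStrategy.GradientScaleStrategy.One

# 'avg' in fleet terms: the accumulated gradient is scaled by
# 1 / num_devices.
build_strategy.gradient_scale_strategy = \
    BuildStrategy.GradientScaleStrategy.CoeffNumDevice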
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -106,6 +106,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_collective_wait)
LIST(REMOVE_ITEM TEST_OPS test_memcpy_op)
LIST(REMOVE_ITEM TEST_OPS test_raw_program_optimizer)
LIST(REMOVE_ITEM TEST_OPS test_fleet_gradient_scale)
endif()

if(WIN32)
73 changes: 73 additions & 0 deletions python/paddle/fluid/tests/unittests/test_fleet_gradient_scale.py
@@ -0,0 +1,73 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
import numpy as np
import os


class TestGradientScale(unittest.TestCase):
    def setUp(self):
        os.environ["PADDLE_TRAINER_ID"] = "0"
        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"

    def mlp(self, input_x, input_y, hid_dim=128, label_dim=2):
        fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
        fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
        prediction = paddle.static.nn.fc(x=[fc_2],
                                         size=label_dim,
                                         activation='softmax')
        cost = paddle.nn.functional.cross_entropy(
            input=prediction, label=input_y)
        avg_cost = paddle.mean(x=cost)
        return avg_cost

    def gen_data(self):
        return {
            "x": np.random.random(size=(128, 32)).astype('float32'),
            "y": np.random.randint(
                2, size=(128, 1)).astype('int64')
        }

    def test_single_gpu(self):
        paddle.enable_static()
        fleet.init(is_collective=True)
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        strategy = fleet.DistributedStrategy()
        strategy.gradient_scale_configs = {'scale_strategy': 'sum'}
        with fluid.program_guard(main_program, startup_program):
            with fluid.unique_name.guard():
                input_x = paddle.static.data(
                    name="x", shape=[None, 32], dtype='float32')
                input_y = paddle.static.data(
                    name="y", shape=[None, 1], dtype='int64')
                cost = self.mlp(input_x=input_x, input_y=input_y)
                output_name = cost.name
                optimizer = fleet.distributed_optimizer(fluid.optimizer.Adam(),
                                                        strategy)
                optimizer.minimize(cost)

        final_strategy = fleet._final_strategy()
        assert final_strategy.gradient_scale_configs['scale_strategy'] == 'sum'


if __name__ == "__main__":
    unittest.main()