
Commit 911c859

optimize pipeline performance with recompute and amp, test=allcase (#34519)
Parent: 1d7b75d

File tree: 4 files changed, +87 −11 lines


python/paddle/fluid/backward.py

Lines changed: 7 additions & 0 deletions
@@ -945,6 +945,13 @@ def _append_backward_ops_with_checkpoints_(
     for op_desc in reversed(added_descs):
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
+
+        # Set device for grad_op according to forward Op
+        if op_desc.has_attr(device_attr_name):
+            op_device = op_desc.attr(device_attr_name)
+            for g_op_desc in grad_op_desc:
+                g_op_desc._set_attr(device_attr_name, op_device)
+
         for key in var_name_dict:
             _rename_arg_(grad_op_desc, key, var_name_dict[key])
         grad_op_descs.extend(grad_op_desc)
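
For context, a minimal stand-alone sketch of what this backward.py change does, using plain dicts as hypothetical stand-ins for Paddle's OpDesc objects (not the Paddle API itself): every gradient op generated for a recomputed forward op inherits that op's op_device attribute, so pipeline stage placement survives the recompute backward pass.

    # Hypothetical stand-ins: dicts instead of OpDesc; attribute name assumed.
    DEVICE_ATTR = "op_device"

    def propagate_device(forward_op, grad_ops):
        """Copy the forward op's device attribute onto its gradient ops."""
        if DEVICE_ATTR in forward_op:
            for g_op in grad_ops:
                g_op[DEVICE_ATTR] = forward_op[DEVICE_ATTR]
        return grad_ops

    forward = {"type": "mul", "op_device": "gpu:1"}
    grads = [{"type": "mul_grad"}, {"type": "elementwise_add_grad"}]
    assert all(g["op_device"] == "gpu:1" for g in propagate_device(forward, grads))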

python/paddle/fluid/contrib/mixed_precision/fp16_lists.py

Lines changed: 2 additions & 0 deletions
@@ -150,6 +150,8 @@ def _update_list(self):
     'c_identity',
     'c_concat',
     'c_allreduce_sum',
+    'concat',
+    'split',
 }
 
 # The set of ops that don't support fp16 calculation
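
The two new entries land in what the surrounding context suggests is the AMP gray list: ops that neither force fp16 nor force fp32 but follow the precision of their inputs. A hedged, simplified sketch of that semantics (illustration only, not Paddle's actual list-resolution code; the example op names are assumptions):

    # Simplified illustration only; the real lists live in fp16_lists.py.
    WHITE_LIST = {"matmul", "mul", "conv2d"}         # assumed examples: cast to fp16
    BLACK_LIST = {"softmax_with_cross_entropy"}      # assumed examples: keep fp32
    GRAY_LIST = {"concat", "split", "elementwise_add"}

    def decide_dtype(op_type, producer_dtype):
        """Pick a compute dtype for op_type given the dtype of its producer."""
        if op_type in WHITE_LIST:
            return "float16"
        if op_type in BLACK_LIST:
            return "float32"
        if op_type in GRAY_LIST:
            return producer_dtype  # gray ops follow the upstream op
        return "float32"

    assert decide_dtype("concat", "float16") == "float16"
    assert decide_dtype("split", "float32") == "float32"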

python/paddle/fluid/contrib/mixed_precision/fp16_utils.py

Lines changed: 22 additions & 1 deletion
@@ -110,6 +110,27 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
                 out_var = block.vars.get(cast_name)
                 if out_var is None or out_var.dtype != dest_dtype:
+                    op_device = op.attr('op_device')
+                    # NOTE(wangxi): optimize for pipeline, reduce one send.
+                    # if in_var is stop_gradient and prev_op device is `all`,
+                    # set cast_op device to `all`, can reduce send cast_var.
+                    # TODO: need remove this after we unified the dynamic
+                    # and static pipeline interface.
+                    if src_dtype == core.VarDesc.VarType.FP32 and in_var.stop_gradient:
+                        prev_op = None
+                        if in_var.op is op:
+                            prev_op = find_true_prev_op(block.ops, op,
+                                                        in_var_name)
+                        elif in_var.op is not None:
+                            prev_op = in_var.op
+
+                        prev_op_device = None
+                        if prev_op is not None:
+                            prev_op_device = prev_op.attr('op_device')
+
+                        if prev_op_device is not None and 'all' in prev_op_device:
+                            op_device = prev_op_device
+
                     out_var = block.create_var(
                         name=cast_name,
                         dtype=dest_dtype,
@@ -124,7 +145,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                         attrs={
                             "in_dtype": in_var.dtype,
                             "out_dtype": out_var.dtype,
-                            "op_device": op.attr("op_device")
+                            "op_device": op_device
                         })
                     num_cast_ops += 1
                 _rename_arg(op, in_var.name, out_var.name)
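
The NOTE above describes the pipeline-specific placement rule. A hedged sketch of just that decision (a simplified stand-in for the logic inside _insert_cast_op, with our own function name): when the variable being cast is stop_gradient and its producer ran on a "gpu:all" device, the cast op is also placed on "gpu:all", so every pipeline stage casts locally and one cross-stage send of the cast result is avoided.

    # Simplified stand-in for the cast-op placement decision; names are ours.
    def choose_cast_device(consumer_device, producer_device, stop_gradient):
        """Return the op_device to assign to the inserted cast op."""
        if stop_gradient and producer_device is not None and "all" in producer_device:
            return producer_device  # replicate the cast on every stage
        return consumer_device      # default: place the cast with its consumer

    assert choose_cast_device("gpu:1", "gpu:all", True) == "gpu:all"   # no extra send
    assert choose_cast_device("gpu:1", "gpu:0", True) == "gpu:1"       # normal placement
    assert choose_cast_device("gpu:1", "gpu:all", False) == "gpu:1"    # gradient needed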

python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py

Lines changed: 56 additions & 10 deletions
@@ -14,6 +14,10 @@
 
 import unittest
 import paddle
+import paddle.fluid as fluid
+import paddle.static as static
+import paddle.distributed.fleet as fleet
+import paddle.distributed.fleet.base.role_maker as role_maker
 import os
 
 paddle.enable_static()
@@ -25,26 +29,34 @@ def setUp(self):
         os.environ[
             "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
 
-    def test_pipeline_optimizer(self):
-        import paddle.distributed.fleet as fleet
-        import paddle.distributed.fleet.base.role_maker as role_maker
-        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
-        fleet.init(role)
-        with paddle.fluid.device_guard("gpu:0"):
+    def net(self):
+        with static.device_guard("gpu:0"):
             input_x = paddle.fluid.layers.data(
                 name="x", shape=[32], dtype='float32')
             input_y = paddle.fluid.layers.data(
                 name="y", shape=[1], dtype='int64')
+            input_z = paddle.fluid.layers.data(
+                name="z", shape=[1], dtype="float32")
+            with static.device_guard("gpu:all"):
+                input_z = input_z * 1.0
+                input_z.stop_gradient = True
             fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
+            fc_1 = fc_1 * input_z
 
-        with paddle.fluid.device_guard("gpu:1"):
+        with static.device_guard("gpu:1"):
             fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
+            fc_2 = fc_2 * input_z
             prediction = paddle.fluid.layers.fc(input=[fc_2],
                                                 size=2,
                                                 act='softmax')
             cost = paddle.fluid.layers.cross_entropy(
                 input=prediction, label=input_y)
             avg_cost = paddle.fluid.layers.mean(x=cost)
+        return avg_cost
+
+    def test_pipeline_optimizer(self):
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        fleet.init(role)
 
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.pipeline = True
@@ -53,9 +65,43 @@ def test_pipeline_optimizer(self):
             'accumulate_steps': 2
         }
 
-        optimizer = paddle.fluid.optimizer.Adam(0.01)
-        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
-        optimizer.minimize(avg_cost)
+        train_prog, startup_prog = static.Program(), static.Program()
+        with static.program_guard(train_prog, startup_prog):
+            with fluid.unique_name.guard():
+                avg_cost = self.net()
+
+                optimizer = paddle.fluid.optimizer.Adam(0.01)
+                optimizer = fleet.distributed_optimizer(
+                    optimizer, strategy=strategy)
+                optimizer.minimize(avg_cost)
+
+    def test_pipeline_amp_optimizer(self):
+        """ test pipeline&amp with device:all """
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        fleet.init(role)
+
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.amp = True
+        strategy.pipeline = True
+        strategy.pipeline_configs = {
+            'micro_batch_size': 1,
+            'accumulate_steps': 2
+        }
+
+        train_prog, startup_prog = static.Program(), static.Program()
+        with static.program_guard(train_prog, startup_prog):
+            with fluid.unique_name.guard():
+                avg_cost = self.net()
+
+                optimizer = paddle.fluid.optimizer.Adam(0.01)
+                optimizer = fleet.distributed_optimizer(
+                    optimizer, strategy=strategy)
+                optimizer.minimize(avg_cost)
+
+        ops = train_prog._pipeline_opt['section_program'].global_block().ops
+        ops = [op.type for op in ops]
+        self.assertEqual(ops.count('send_v2'), 1)
+        self.assertEqual(ops.count('recv_v2'), 1)
 
 
 if __name__ == "__main__":
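
The new assertions pin down what the optimization buys: with input_z computed under device_guard("gpu:all") and AMP enabled, the generated section program should still contain exactly one send_v2 and one recv_v2, i.e. no extra send is introduced for the cast of the replicated variable. A small hedged helper sketch (hypothetical, it simply mirrors those assertions) for counting communication ops in a static Program:

    # Hypothetical helper, mirrors the assertions in the test above.
    def count_comm_ops(program):
        """Count pipeline send/recv ops in a static Program's global block."""
        op_types = [op.type for op in program.global_block().ops]
        return op_types.count('send_v2'), op_types.count('recv_v2')

    # Usage sketch: sends, recvs = count_comm_ops(section_program)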
