
Commit 5f65ff9

[hybrid performance] Optimize pipeline send wait (#34086)
1 parent 9cda059 commit 5f65ff9

File tree

paddle/fluid/operators/nop_op.cc
python/paddle/fluid/optimizer.py

2 files changed, +156 -14 lines changed

paddle/fluid/operators/nop_op.cc

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class NopOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {}

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(framework::proto::VarType::FP32,
                                   ctx.GetPlace());
  }
};

class NopOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor) The input tensor of nop op.").AsDuplicable();
    AddOutput("Out", "(Tensor) The output tensor of nop op.").AsDuplicable();
    AddComment(R"DOC(
Nop Operator

Do nothing, except let the input and output tensors occupy the memory and
establish the dependency between input and output tensors.
)DOC");
  }
};

template <typename T>
class NopKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {}
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_WITHOUT_GRADIENT(nop, ops::NopOp, ops::NopOpMaker);

REGISTER_OP_CPU_KERNEL(nop, ops::NopKernel<float>);

REGISTER_OP_CUDA_KERNEL(nop, ops::NopKernel<float>);

REGISTER_OP_NPU_KERNEL(nop, ops::NopKernel<float>);
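
The nop op performs no computation; it only lists a tensor as both input and output so the framework keeps the memory alive and records a dependency edge. A minimal usage sketch (illustrative only, assuming a Paddle build where this new op is registered; the variable name and shape are made up):

import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
# A tensor whose memory must stay allocated until a later op has run.
x = block.create_var(name='x', shape=[4, 8], dtype='float32')
# 'nop' computes nothing; it only makes `x` an input and output of an op,
# which keeps the tensor alive and creates a dependency in the block.
block.append_op(type='nop', inputs={'X': [x]}, outputs={'Out': [x]})

This is exactly how `_optimize_forward_send_sync` below uses the op: after the blocking `c_sync_comm_stream` is removed, a nop on the sent tensor preserves its lifetime.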

python/paddle/fluid/optimizer.py

Lines changed: 90 additions & 14 deletions
@@ -4221,6 +4221,8 @@ def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0):
         self._param_device_map = None
         self._pipeline_pair = []
         self._pp_ring_map = dict()
+        self.output_var_to_op = None
+        self.input_var_to_op = None
 
         # insert allreduce op to sync global information for global
         # gradient clip and amp
@@ -4657,6 +4659,9 @@ def _check_validation(self, block):
             int(self._op_role.Optimize),
             int(self._op_role.Backward) | int(self._op_role.Loss),
         ]
+        pre_stage_id = None
+        decrease_flag = False
+        in_optimize = False
         for op in block.ops:
             if not op._has_kernel(op.type):
                 assert op.type == "conditional_block" and (
@@ -4666,25 +4671,49 @@ def _check_validation(self, block):
             assert op.has_attr(self._op_role_key), (
                 "op ({}) has no {} attribute.".format(op.type,
                                                       self._op_role_key))
-            assert int(op.attr(self._op_role_key)) in valid_op_role_value, \
+            op_role = op.attr(self._op_role_key)
+            assert int(op_role) in valid_op_role_value, \
                 "op_role {} for op {} must be one of {}".format(
-                    op.attr(self._op_role_key),
+                    op_role,
                     op.type,
                     valid_op_role_value)
+            if int(op_role) == int(self._op_role.Optimize):
+                in_optimize = True
+
             assert op.has_attr(self._op_device_key), (
                 "op ({}) has no {} attribute.".format(op.type,
                                                       self._op_device_key))
 
             device = op.attr(self._op_device_key)
             assert device, ("op_device attribute for op "
                             "{} has not been set.".format(op.type))
-            if device == "gpu:all": continue
+            if device == "gpu:all" or device == "npu:all": continue
+
             dev_type = device.split(':')[0]
+            stage_id = int(device.split(':')[1])
             assert dev_type == "gpu" or dev_type == 'npu', (
                 "Now only gpu and npu devices are supported "
                 "for pipeline parallelism.")
-            if not device in device_list:
+
+            if device not in device_list:
                 device_list.append(device)
+
+            if not in_optimize:
+                if pre_stage_id is not None:
+                    interval = stage_id - pre_stage_id
+                    assert abs(interval) <= 1, \
+                        "The stage interval of two consecutive ops in the pipeline must be <= 1, " \
+                        "but the interval of op={} and the previous op is {}".format(op, interval)
+                    # Stages must be in order, e.g. Forward(0 1 2 3 4), Backward(4 3 2 1 0).
+                    # An unordered sequence such as Forward(0 1 2 3 4 3 4) raises an error.
+                    if interval == -1:
+                        decrease_flag = True
+                    if interval == 1:
+                        assert decrease_flag is False, \
+                            "Pipeline stage must be in order, " \
+                            "please check the stage of op={}".format(op)
+                pre_stage_id = stage_id
+
         return device_list
 
     def _insert_sendrecv_ops_for_boundaries(self, block):
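
The added check enforces two rules outside the Optimize phase: consecutive ops may only stay on the same stage or move by one stage, and once the stage id starts decreasing (the backward pass) it must not increase again. A standalone sketch of the same rule applied to a plain list of stage ids (`check_stage_order` is a hypothetical helper, not part of the commit):

def check_stage_order(stage_ids):
    pre_stage_id, decrease_flag = None, False
    for stage_id in stage_ids:
        if pre_stage_id is not None:
            interval = stage_id - pre_stage_id
            # consecutive ops may only stay put or move by one stage
            assert abs(interval) <= 1, "stage jump of {} is not allowed".format(interval)
            if interval == -1:
                decrease_flag = True  # the backward pass has started
            if interval == 1:
                assert not decrease_flag, "stages must not increase after decreasing"
        pre_stage_id = stage_id

check_stage_order([0, 0, 1, 2, 3, 3, 2, 1, 0])  # valid: forward then backward
# check_stage_order([0, 1, 2, 3, 2, 3])         # raises: stage increases after decreasing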
@@ -4826,14 +4855,16 @@ def _insert_send_recv(cur_id, prev_id):
                     })
                 extra_index_info['index'] += 1
                 insert_index = None
+
                 if int(op_role) == int(self._op_role.Backward):
                     insert_index = extra_index_info[
                         'first_optimize_index']
                     new_op_role = self._op_role.Optimize
                 else:
                     insert_index = index
                     new_op_role = self._op_role.Backward
-                block._insert_op_without_sync(
+
+                sync_comm_op = block._insert_op_without_sync(
                     index=insert_index + extra_index_info['index'],
                     type='c_sync_comm_stream',
                     inputs={'X': [var]},
@@ -4843,8 +4874,11 @@ def _insert_send_recv(cur_id, prev_id):
                         self._op_role_key: new_op_role,
                         'ring_id': ring_id,
                     })
+
                 if int(op_role) == int(self._op_role.Forward):
+                    sync_comm_op._set_attr('pipeline_flag', '')
                     extra_index_info['index'] += 1
+
                 var_shape = list(var.shape)
                 var_shape[0] = self.micro_batch_size if var_shape[
                     0] < 0 else var_shape[0]
@@ -5153,17 +5187,55 @@ def _get_input_output_info(self, block):
         Get info of op input and output.
         '''
         # A map from output var to op which generate it.
-        self.output_var_to_op = dict()
+        output_var_to_op = defaultdict(list)
         # A map from var to op which takes it as input.
-        self.input_var_to_op = dict()
+        input_var_to_op = defaultdict(list)
 
-        for index, op in enumerate(list(block.ops)):
+        for index, op in enumerate(block.ops):
             for var_name in op.input_arg_names:
-                ops = self.input_var_to_op.setdefault(var_name, [])
-                ops.append([op, index])
+                input_var_to_op[var_name].append([op, index])
             for var_name in op.output_arg_names:
-                ops = self.output_var_to_op.setdefault(var_name, [])
-                ops.append([op, index])
+                output_var_to_op[var_name].append([op, index])
+
+        return output_var_to_op, input_var_to_op
+
+    def _optimize_forward_send_sync(self, program):
+        """
+        Optimize the forward send's sync_comm_stream schedule.
+        """
+        if self.schedule_mode != '1F1B': return
+
+        block = program.block(0)
+
+        backward_recv_index = None
+        for index, op in enumerate(block.ops):
+            if op.type == 'recv_v2' and self._is_backward_op(op):
+                backward_recv_index = index
+                break
+
+        if backward_recv_index is None: return
+
+        offset = 0
+        for index, op in enumerate(list(block.ops)):
+            if index >= backward_recv_index: break
+            if op.type == 'c_sync_comm_stream' and op.has_attr('pipeline_flag'):
+                var_name = op.input_arg_names[0]
+                var = block.var(var_name)
+                block._remove_op(index + offset, sync=False)
+                offset -= 1
+                # NOTE:
+                # 1. When the backward recv is completed, it indicates
+                #    that the forward send is completed too. So we only need
+                #    to use the nop op to prevent memory release.
+                # 2. Because we removed the sync_comm_op,
+                #    we insert the nop op after recv_op.
+                block._insert_op_without_sync(
+                    index=backward_recv_index,
+                    type='nop',
+                    inputs={'X': [var]},
+                    outputs={'Out': [var]},
+                    attrs={self._op_role_key: self._op_role.Backward})
+        block._sync_with_cpp()
 
     def minimize(self,
                  loss,
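
Together with the `pipeline_flag` attribute set on forward-send `c_sync_comm_stream` ops earlier in this diff, this pass implements the commit's optimization: under the 1F1B schedule the forward pass no longer blocks on its send, because completion of the first backward `recv_v2` already implies the peer has received the data, and the nop inserted after that recv only keeps the sent tensor's memory alive. A self-contained toy sketch of the rewrite, using plain dicts instead of fluid ops (all names below are illustrative, not Paddle APIs):

def optimize_forward_send_sync(ops):
    # Find the first backward recv; once it has finished, every forward send
    # issued before it must have completed as well.
    backward_recv_index = next(
        (i for i, op in enumerate(ops)
         if op['type'] == 'recv_v2' and op.get('role') == 'Backward'), None)
    if backward_recv_index is None:
        return ops

    kept, kept_alive_vars = [], []
    for i, op in enumerate(ops):
        if (i < backward_recv_index and op['type'] == 'c_sync_comm_stream'
                and 'pipeline_flag' in op):
            kept_alive_vars.append(op['var'])  # drop the blocking sync
            continue
        kept.append(op)
        if i == backward_recv_index:
            # keep the sent tensors alive with nops placed after the recv
            kept.extend({'type': 'nop', 'var': v, 'role': 'Backward'}
                        for v in kept_alive_vars)
    return kept

ops = [
    {'type': 'send_v2', 'var': 'x'},
    {'type': 'c_sync_comm_stream', 'var': 'x', 'pipeline_flag': ''},
    {'type': 'elementwise_add', 'var': 'y'},
    {'type': 'recv_v2', 'var': 'dy', 'role': 'Backward'},
]
print([op['type'] for op in optimize_forward_send_sync(ops)])
# -> ['send_v2', 'elementwise_add', 'recv_v2', 'nop']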
@@ -5200,7 +5272,8 @@ def minimize(self,
             loss, startup_program, parameter_list, no_grad_set)
         self._param_device_map = self._origin_optimizer._param_device_map
 
-        self._get_input_output_info(main_block)
+        self.output_var_to_op, self.input_var_to_op = \
+            self._get_input_output_info(main_block)
         # Step1: add default op_device attribute for ops.
         self._add_op_device_attr(main_block)
         device_list = self._check_validation(main_block)
@@ -5229,6 +5302,10 @@ def device_cmp(device1, device2):
         for p in program_list:
             self._create_vars(p.global_block(), main_block)
 
+        self.local_rank %= len(device_list)
+        # Step3.5: optimize forward send sync_comm to overlap send and recv
+        self._optimize_forward_send_sync(program_list[self.local_rank])
+
         # Step4: Special Case: process persistable vars that exist in
         # multiple sections
         # FIXME
@@ -5238,7 +5315,6 @@ def device_cmp(device1, device2):
         # Step5: Add sub blocks for section programs
         self._add_sub_blocks(main_block, program_list)
 
-        self.local_rank %= len(device_list)
         place_list = []
         for dev in device_list:
             dev_index = int(dev.split(":")[1])
