Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions paddle/fluid/operators/nop_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class NopOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
};

// Declares the proto (inputs, outputs and documentation) of the `nop`
// operator.
class NopOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Marked `override` so the compiler verifies that this really overrides
  // the base class hook, consistent with the other overriding methods in
  // this file.
  void Make() override {
    // Both slots are duplicable, i.e. each may bind a list of variables.
    AddInput("X", "(Tensor) The input tensor of nop op.").AsDuplicable();
    AddOutput("Out", "(Tensor) The output tensor of nop op.").AsDuplicable();
    AddComment(R"DOC(
Nop Operator

Do nothing, except let the input and output tensors occupy the memory and
establish the dependency between input and output tensors.
)DOC");
  }
};

// Kernel for the `nop` operator: computing it does nothing at all.
template <class T>
class NopKernel : public framework::OpKernel<T> {
 public:
  // Empty on purpose; the execution context is never read.
  void Compute(const framework::ExecutionContext& /*ctx*/) const override {}
};

} // namespace operators
} // namespace paddle

// Short alias used by the registration macros below.
namespace ops = paddle::operators;

// Register the `nop` operator with its maker; no gradient op is registered.
REGISTER_OP_WITHOUT_GRADIENT(nop, ops::NopOp, ops::NopOpMaker);

// Register the float kernel for CPU. NopOp::GetExpectedKernelType always
// requests FP32, so float is the only instantiation registered here.
REGISTER_OP_CPU_KERNEL(nop, ops::NopKernel<float>);

// Register the same float kernel for CUDA.
REGISTER_OP_CUDA_KERNEL(nop, ops::NopKernel<float>);

// Register the same float kernel for NPU.
REGISTER_OP_NPU_KERNEL(nop, ops::NopKernel<float>);
104 changes: 90 additions & 14 deletions python/paddle/fluid/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4221,6 +4221,8 @@ def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0):
self._param_device_map = None
self._pipeline_pair = []
self._pp_ring_map = dict()
self.output_var_to_op = None
self.input_var_to_op = None

# insert allreduce op to sync global information for global
# gradient clip and amp
Expand Down Expand Up @@ -4657,6 +4659,9 @@ def _check_validation(self, block):
int(self._op_role.Optimize),
int(self._op_role.Backward) | int(self._op_role.Loss),
]
pre_stage_id = None
decrease_flag = False
in_optimize = False
for op in block.ops:
if not op._has_kernel(op.type):
assert op.type == "conditional_block" and (
Expand All @@ -4666,25 +4671,49 @@ def _check_validation(self, block):
assert op.has_attr(self._op_role_key), (
"op ({}) has no {} attribute.".format(op.type,
self._op_role_key))
assert int(op.attr(self._op_role_key)) in valid_op_role_value, \
op_role = op.attr(self._op_role_key)
assert int(op_role) in valid_op_role_value, \
"op_role {} for op {} must be one of {}".format(
op.attr(self._op_role_key),
op_role,
op.type,
valid_op_role_value)
if int(op_role) == int(self._op_role.Optimize):
in_optimize = True

assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type,
self._op_device_key))

device = op.attr(self._op_device_key)
assert device, ("op_device attribute for op "
"{} has not been set.".format(op.type))
if device == "gpu:all": continue
if device == "gpu:all" or device == "npu:all": continue

dev_type = device.split(':')[0]
stage_id = int(device.split(':')[1])
assert dev_type == "gpu" or dev_type == 'npu', (
"Now only gpu and npu devices are supported "
"for pipeline parallelism.")
if not device in device_list:

if device not in device_list:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个的原因是啥?

device_list.append(device)

if not in_optimize:
if pre_stage_id is not None:
interval = stage_id - pre_stage_id
assert abs(interval) <= 1, \
"The stage interval of two consecutive ops in the pipeline must be < = 1," \
"but the interval of op={} and prev op is {}".format(op, interval)
# stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0)
# if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error
if interval == -1:
decrease_flag = True
if interval == 1:
assert decrease_flag is False, \
"Pipeline stage must be in order, " \
"please check the stage of op={}".format(op)
pre_stage_id = stage_id

return device_list

def _insert_sendrecv_ops_for_boundaries(self, block):
Expand Down Expand Up @@ -4826,14 +4855,16 @@ def _insert_send_recv(cur_id, prev_id):
})
extra_index_info['index'] += 1
insert_index = None

if int(op_role) == int(self._op_role.Backward):
insert_index = extra_index_info[
'first_optimize_index']
new_op_role = self._op_role.Optimize
else:
insert_index = index
new_op_role = self._op_role.Backward
block._insert_op_without_sync(

sync_comm_op = block._insert_op_without_sync(
index=insert_index + extra_index_info['index'],
type='c_sync_comm_stream',
inputs={'X': [var]},
Expand All @@ -4843,8 +4874,11 @@ def _insert_send_recv(cur_id, prev_id):
self._op_role_key: new_op_role,
'ring_id': ring_id,
})

if int(op_role) == int(self._op_role.Forward):
sync_comm_op._set_attr('pipeline_flag', '')
extra_index_info['index'] += 1

var_shape = list(var.shape)
var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0]
Expand Down Expand Up @@ -5153,17 +5187,55 @@ def _get_input_output_info(self, block):
Get info of op input and output.
'''
# A map from output var to op which generate it.
self.output_var_to_op = dict()
output_var_to_op = defaultdict(list)
# A map from var to op which takes it as input.
self.input_var_to_op = dict()
input_var_to_op = defaultdict(list)

for index, op in enumerate(list(block.ops)):
for index, op in enumerate(block.ops):
for var_name in op.input_arg_names:
ops = self.input_var_to_op.setdefault(var_name, [])
ops.append([op, index])
input_var_to_op[var_name].append([op, index])
for var_name in op.output_arg_names:
ops = self.output_var_to_op.setdefault(var_name, [])
ops.append([op, index])
output_var_to_op[var_name].append([op, index])

return output_var_to_op, input_var_to_op

def _optimize_forward_send_sync(self, program):
"""
optimize forward send's sync_comm_stream schedule
"""
# Replaces each 'c_sync_comm_stream' op that carries the 'pipeline_flag'
# attribute and sits before the first backward 'recv_v2' op with a 'nop' op
# inserted at that recv's original index.
# NOTE(review): leading indentation is missing in this view; the nesting
# described in these comments is inferred from the statements themselves.
#
# Only the '1F1B' schedule mode is handled; any other mode returns untouched.
if self.schedule_mode != '1F1B': return

# All edits are applied to block 0 of the given program.
block = program.block(0)

# Locate the first 'recv_v2' op that is a backward op.
backward_recv_index = None
for index, op in enumerate(block.ops):
if op.type == 'recv_v2' and self._is_backward_op(op):
backward_recv_index = index
break

# Nothing to do when the program has no backward recv.
if backward_recv_index is None: return

# `offset` is decremented once per removed op, so `index + offset` turns an
# index in the snapshot `list(block.ops)` into the op's current position.
offset = 0
for index, op in enumerate(list(block.ops)):
# Only ops that originally preceded the backward recv are considered.
if index >= backward_recv_index: break
if op.type == 'c_sync_comm_stream' and op.has_attr('pipeline_flag'):
# The variable that was being synced is the op's first input.
var_name = op.input_arg_names[0]
var = block.var(var_name)
block._remove_op(index + offset, sync=False)
offset -= 1
# NOTE:
# 1. When the backward recv is completed, it indicates
# that the forward send is completed too. So we only need
# to use the NOP op to prevent memory release.
# 2. Because we removed sync_comm_op,
# we will insert NOP after recv_op.
# The removal above shifted the recv one slot earlier, so inserting at
# the original `backward_recv_index` places the nop after the recv.
block._insert_op_without_sync(
index=backward_recv_index,
type='nop',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={self._op_role_key: self._op_role.Backward})
# Removals and insertions above were made without syncing (sync=False /
# _insert_op_without_sync); sync explicitly here.
block._sync_with_cpp()

def minimize(self,
loss,
Expand Down Expand Up @@ -5200,7 +5272,8 @@ def minimize(self,
loss, startup_program, parameter_list, no_grad_set)
self._param_device_map = self._origin_optimizer._param_device_map

self._get_input_output_info(main_block)
self.output_var_to_op, self.input_var_to_op = \
self._get_input_output_info(main_block)
# Step1: add default op_device attribute for ops.
self._add_op_device_attr(main_block)
device_list = self._check_validation(main_block)
Expand Down Expand Up @@ -5229,6 +5302,10 @@ def device_cmp(device1, device2):
for p in program_list:
self._create_vars(p.global_block(), main_block)

self.local_rank %= len(device_list)
# Step3.5: optimize forward send sync_comm to overlap send and recv
self._optimize_forward_send_sync(program_list[self.local_rank])

# Step4: Special Case: process persistable vars that exist in
# multiple sections
# FIXME
Expand All @@ -5238,7 +5315,6 @@ def device_cmp(device1, device2):
# Step5: Add sub blocks for section programs
self._add_sub_blocks(main_block, program_list)

self.local_rank %= len(device_list)
place_list = []
for dev in device_list:
dev_index = int(dev.split(":")[1])
Expand Down