68 changes: 31 additions & 37 deletions python/paddle/fluid/optimizer.py
@@ -4397,6 +4397,10 @@ def _is_loss_grad_op(self, op):
return op_role & int(self._op_role.Backward) and op_role & int(
self._op_role.Loss)

def _is_forward_op(self, op):
return self._op_role_key in op.attr_names and (
int(op.attr(self._op_role_key)) == int(self._op_role.Forward))

def _is_backward_op(self, op):
return self._op_role_key in op.attr_names and (
int(op.attr(self._op_role_key)) & int(self._op_role.Backward))
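Note: the new `_is_forward_op` mirrors the existing `_is_backward_op`, but compares with `==` rather than a bitwise `&`, because Forward's role value is zero and a bitwise AND against it can never be truthy. Below is a minimal standalone sketch of the two checks, assuming Paddle's standard OpRole bit flags (Forward=0x0000, Backward=0x0001, Loss=0x0100); `FakeOp` is a hypothetical stand-in for a real operator, for illustration only.

```python
# Standalone sketch of the role checks above; assumes Paddle's standard
# OpRole bit flags. FakeOp is a hypothetical stand-in for a real op.
OP_ROLE_KEY = 'op_role'
FORWARD, BACKWARD, LOSS = 0x0000, 0x0001, 0x0100

class FakeOp:
    def __init__(self, role):
        self._attrs = {OP_ROLE_KEY: role}
        self.attr_names = list(self._attrs)

    def attr(self, name):
        return self._attrs[name]

def is_forward_op(op):
    # Forward is 0x0000, so the check must be an equality test;
    # a bitwise AND with zero is always falsy.
    return OP_ROLE_KEY in op.attr_names and (
        int(op.attr(OP_ROLE_KEY)) == FORWARD)

def is_backward_op(op):
    # Backward may be OR-ed with Loss (giving 0x0101), so a bitwise
    # AND is used here instead of equality.
    return OP_ROLE_KEY in op.attr_names and (
        int(op.attr(OP_ROLE_KEY)) & BACKWARD)

assert is_forward_op(FakeOp(FORWARD))
assert is_backward_op(FakeOp(BACKWARD | LOSS))
assert not is_forward_op(FakeOp(BACKWARD | LOSS))
```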
@@ -4705,10 +4709,6 @@ def _check_validation(self, block):
int(self._op_role.Optimize),
int(self._op_role.Backward) | int(self._op_role.Loss),
]
pre_stage_id = None
decrease_flag = False
in_optimize = False
in_forward = True
for op in block.ops:
if not op._has_kernel(op.type):
assert op.type == "conditional_block" and (
@@ -4724,10 +4724,6 @@
op_role,
op.type,
valid_op_role_value)
if int(op_role) == int(self._op_role.Optimize):
in_optimize = True
if int(op_role) == int(self._op_role.Backward):
in_forward = False

assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type,
@@ -4739,34 +4735,13 @@
if device == f"{self._device}:all": continue

dev_type = device.split(':')[0]
stage_id = int(device.split(':')[1])
assert dev_type == "gpu" or dev_type == 'npu', (
"Now only gpu and npu devices are supported "
"for pipeline parallelism.")

if device not in device_list:
device_list.append(device)

if not in_optimize:
if pre_stage_id is not None:
interval = stage_id - pre_stage_id
assert abs(interval) <= 1, \
"The stage interval of two consecutive ops in the pipeline must be < = 1," \
"but the interval of op={} and prev op is {}".format(op, interval)
# stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0)
# if stages are out of order, e.g. Forward(0 1 2 3 4 3 4), an error is reported
if in_forward:
assert interval >= 0, \
"Pipeline stage must be sequential increment in Forward, prev_stage={}, " \
"please check the stage of op={}".format(pre_stage_id, op)
else:
# FIXME(wangxi): recompute check failed
pass
#assert interval <=0, \
# "Pipeline stage must be sequential decrement in Backward, prev_stage={}, " \
# "please check the stage of op={}".format(pre_stage_id, op)
pre_stage_id = stage_id

return device_list
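Note: the hunks above remove the global stage-ordering scan from `_check_validation`: the `pre_stage_id`/`in_forward`/`in_optimize` bookkeeping asserted that stage ids step by at most one and never decrease during forward, and the corresponding backward check was already commented out per the FIXME. For reference, a tiny sketch of the forward pattern that scan enforced; this is not part of the patch.

```python
# The removed check allowed consecutive forward ops to stay on the same
# stage or step up by one, e.g. 0 1 2 3 4; a dip such as
# 0 1 2 3 4 3 4 tripped the assertion.
def forward_stages_valid(stage_ids):
    return all(0 <= b - a <= 1 for a, b in zip(stage_ids, stage_ids[1:]))

assert forward_stages_valid([0, 1, 1, 2, 3, 4])
assert not forward_stages_valid([0, 1, 2, 3, 4, 3, 4])
```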

def _insert_sendrecv_ops_for_boundaries(self, block):
@@ -4820,6 +4795,25 @@ def _insert_sendrecv_ops_for_boundaries(self, block):

device_type = cur_device.split(':')[0] + ':'

def _check_stage(cur_id, prev_id):
# check send/recv stage valid
is_forward = self._is_forward_op(op)
is_backward = self._is_backward_op(op)
assert is_forward or is_backward, \
'send/recv in pipeline should only be inserted in forward or backward, ' \
'please check the op_role of op={}'.format(op)

if is_forward:
assert prev_id < cur_id, \
"In forward, send/recv can only be passed forward, but now " \
"prev_stage={} great than cur_stage={}, please check op_device of op={}".format(
prev_id, cur_id, op)
elif is_backward:
assert prev_id > cur_id, \
"In backward, send/recv can only be passed backward, but now " \
"prev_stage={} less than cur_stage={}, please check op_device of op={}".format(
prev_id, cur_id, op)

def _insert_send_recv(cur_id, prev_id):
cur_dev = device_type + str(cur_id)
prev_dev = device_type + str(prev_id)
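Note: `_check_stage` replaces that global scan with a local check at each pipeline boundary: in forward, a variable may only be sent to a higher stage id; in backward, only to a lower one. A standalone restatement of the rule, where `is_forward` stands in for the op-role test on the enclosing op (hypothetical names, sketch only):

```python
# Restatement of the _check_stage rule: forward traffic moves to later
# stages, backward traffic moves to earlier stages.
def check_stage(cur_id, prev_id, is_forward):
    if is_forward:
        # Activations flow downstream: prev_id must be below cur_id.
        assert prev_id < cur_id, (
            "forward send/recv must target a later stage, got "
            "prev_stage={} cur_stage={}".format(prev_id, cur_id))
    else:
        # Gradients flow upstream: prev_id must be above cur_id.
        assert prev_id > cur_id, (
            "backward send/recv must target an earlier stage, got "
            "prev_stage={} cur_stage={}".format(prev_id, cur_id))

check_stage(cur_id=2, prev_id=1, is_forward=True)   # ok: 1 -> 2 forward
check_stage(cur_id=1, prev_id=2, is_forward=False)  # ok: 2 -> 1 backward
```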
@@ -4890,9 +4884,9 @@ def _insert_send_recv(cur_id, prev_id):
var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0]

numel = np.prod(var.shape)
assert numel % self.mp_degree == 0, \
"The numel={} must be divisible by mp_degree={}".format(numel, self.mp_degree)
numel = np.prod(var_shape)
use_mp = (self.mp_degree > 1) and (
numel % self.mp_degree == 0)

if 'subprog' in var.name:
# For recompute, if the checkpoints var is layer_norm_6.tmp_2
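Note: the old code computed `numel` from `var.shape`, ignoring the micro-batch substitution applied to `var_shape` just above, and hard-failed whenever `numel` was not divisible by `mp_degree`. The new `use_mp` flag computes `numel` from the adjusted `var_shape` and falls back to whole-tensor `send_v2`/`recv_v2` when an even split is impossible. A sketch of that decision under the same assumptions (the helper name is hypothetical):

```python
import numpy as np

# partial_send/partial_recv split the tensor across mp_degree ranks, so
# they are only usable when the element count divides evenly; otherwise
# fall back to the whole-tensor ops instead of asserting.
def pick_send_op(var_shape, mp_degree, micro_batch_size):
    var_shape = list(var_shape)
    if var_shape[0] < 0:  # dynamic batch dim: substitute the micro batch
        var_shape[0] = micro_batch_size
    numel = int(np.prod(var_shape))
    use_mp = mp_degree > 1 and numel % mp_degree == 0
    return 'partial_send' if use_mp else 'send_v2'

assert pick_send_op([-1, 1024], mp_degree=4, micro_batch_size=2) == 'partial_send'
assert pick_send_op([-1, 1023], mp_degree=4, micro_batch_size=2) == 'send_v2'
assert pick_send_op([-1, 1024], mp_degree=1, micro_batch_size=2) == 'send_v2'
```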
@@ -4919,6 +4913,8 @@
extra_index_info['index'] += 1
return

_check_stage(cur_id, prev_id)

block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='c_sync_calc_stream',
@@ -4931,8 +4927,7 @@
extra_index_info['index'] += 1
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='send_v2'
if self.mp_degree == 1 else 'partial_send',
type='send_v2' if not use_mp else 'partial_send',
inputs={'X': var},
attrs={
self._op_device_key: prev_dev,
Expand Down Expand Up @@ -4968,8 +4963,7 @@ def _insert_send_recv(cur_id, prev_id):
extra_index_info['index'] += 1
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='recv_v2'
if self.mp_degree == 1 else 'partial_recv',
type='recv_v2' if not use_mp else 'partial_recv',
outputs={'Out': [var]},
attrs={
'out_shape': var_shape,
@@ -4984,7 +4978,7 @@
'id': self.mp_rank,
})
extra_index_info['index'] += 1
if self.mp_degree > 1:
if use_mp:
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='partial_allgather',
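Note: `use_mp` now gates all three partial ops consistently: `partial_send` on the sending stage, `partial_recv` on the receiving stage, and the trailing `partial_allgather` that reassembles the full tensor from each rank's 1/mp_degree slice. A sketch of the op sequence this pass inserts at one boundary (an illustrative helper, not the pass itself):

```python
# With use_mp, each rank moves numel / mp_degree elements and the full
# tensor is restored by partial_allgather; without it, send_v2/recv_v2
# move the whole tensor and no allgather is needed.
def boundary_op_types(use_mp):
    sender = ['c_sync_calc_stream',
              'partial_send' if use_mp else 'send_v2']
    receiver = ['partial_recv' if use_mp else 'recv_v2']
    if use_mp:
        receiver.append('partial_allgather')
    return sender, receiver

assert boundary_op_types(True) == (
    ['c_sync_calc_stream', 'partial_send'],
    ['partial_recv', 'partial_allgather'])
assert boundary_op_types(False) == (
    ['c_sync_calc_stream', 'send_v2'], ['recv_v2'])
```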