From 91662898ec51c024de9f8d706052048b9e48d36e Mon Sep 17 00:00:00 2001 From: WangXi Date: Thu, 12 Aug 2021 11:04:27 +0800 Subject: [PATCH 1/5] NPU use squared_l2_norm in GradientClipByGlobalNorm --- python/paddle/fluid/clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 04fb45cd3ae22d..d48cea48a76fd4 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -40,7 +40,7 @@ def _squared_l2_norm(x): This OP returns the squared L2 norm of a tensor. """ - if core.is_compiled_with_npu() or core.is_compiled_with_xpu(): + if core.is_compiled_with_xpu(): square = layers.square(x) sum_square = layers.reduce_sum(square) return sum_square From 5dea68d15b7c3a3722d6c7779dfb153e23b41359 Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 13 Aug 2021 10:24:19 +0800 Subject: [PATCH 2/5] pass pipeline check --- python/paddle/fluid/clip.py | 2 +- python/paddle/fluid/optimizer.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index d48cea48a76fd4..04fb45cd3ae22d 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -40,7 +40,7 @@ def _squared_l2_norm(x): This OP returns the squared L2 norm of a tensor. """ - if core.is_compiled_with_xpu(): + if core.is_compiled_with_npu() or core.is_compiled_with_xpu(): square = layers.square(x) sum_square = layers.reduce_sum(square) return sum_square diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 7ad94f4be3eb2f..e908349d9da252 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4750,15 +4750,17 @@ def _check_validation(self, block): if not in_optimize: if pre_stage_id is not None: interval = stage_id - pre_stage_id - assert abs(interval) <= 1, \ - "The stage interval of two consecutive ops in the pipeline must be < = 1," \ - "but the interval of op={} and prev op is {}".format(op, interval) + # assert abs(interval) <= 1, \ + # "The stage interval of two consecutive ops in the pipeline must be < = 1," \ + # "but the interval of op={} and prev op is {}".format(op, interval) + # stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0) # if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error if in_forward: - assert interval >= 0, \ - "Pipeline stage must be sequential increment in Forward, prev_stage={}, " \ - "please check the stage of op={}".format(pre_stage_id, op) + pass + # assert interval >= 0, \ + # "Pipeline stage must be sequential increment in Forward, prev_stage={}, " \ + # "please check the stage of op={}".format(pre_stage_id, op) else: # FIXME(wangxi): recompute check failed pass From 9abe5ed9231a40a38b59dab3c93a871fbf0fe0fd Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 13 Aug 2021 12:31:23 +0800 Subject: [PATCH 3/5] move pipelien check into insert_send_recv --- python/paddle/fluid/optimizer.py | 57 +++++++++++++------------------- 1 file changed, 23 insertions(+), 34 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index e908349d9da252..65accc25361fb2 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4397,6 +4397,10 @@ def _is_loss_grad_op(self, op): return op_role & int(self._op_role.Backward) and op_role & int( self._op_role.Loss) + def _is_forward_op(self, op): + return self._op_role_key in op.attr_names and ( + int(op.attr(self._op_role_key)) == int(self._op_role.Forward)) + def _is_backward_op(self, op): return self._op_role_key in op.attr_names and ( int(op.attr(self._op_role_key)) & int(self._op_role.Backward)) @@ -4705,10 +4709,6 @@ def _check_validation(self, block): int(self._op_role.Optimize), int(self._op_role.Backward) | int(self._op_role.Loss), ] - pre_stage_id = None - decrease_flag = False - in_optimize = False - in_forward = True for op in block.ops: if not op._has_kernel(op.type): assert op.type == "conditional_block" and ( @@ -4724,10 +4724,6 @@ def _check_validation(self, block): op_role, op.type, valid_op_role_value) - if int(op_role) == int(self._op_role.Optimize): - in_optimize = True - if int(op_role) == int(self._op_role.Backward): - in_forward = False assert op.has_attr(self._op_device_key), ( "op ({}) has no {} attribute.".format(op.type, @@ -4739,7 +4735,6 @@ def _check_validation(self, block): if device == f"{self._device}:all": continue dev_type = device.split(':')[0] - stage_id = int(device.split(':')[1]) assert dev_type == "gpu" or dev_type == 'npu', ( "Now only gpu and npu devices are supported " "for pipeline parallelism.") @@ -4747,28 +4742,6 @@ def _check_validation(self, block): if device not in device_list: device_list.append(device) - if not in_optimize: - if pre_stage_id is not None: - interval = stage_id - pre_stage_id - # assert abs(interval) <= 1, \ - # "The stage interval of two consecutive ops in the pipeline must be < = 1," \ - # "but the interval of op={} and prev op is {}".format(op, interval) - - # stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0) - # if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error - if in_forward: - pass - # assert interval >= 0, \ - # "Pipeline stage must be sequential increment in Forward, prev_stage={}, " \ - # "please check the stage of op={}".format(pre_stage_id, op) - else: - # FIXME(wangxi): recompute check failed - pass - #assert interval <=0, \ - # "Pipeline stage must be sequential decrement in Backward, prev_stage={}, " \ - # "please check the stage of op={}".format(pre_stage_id, op) - pre_stage_id = stage_id - return device_list def _insert_sendrecv_ops_for_boundaries(self, block): @@ -5007,9 +4980,25 @@ def _insert_send_recv(cur_id, prev_id): "Now only 'F-then-B' and '1F1B' are supported." "The given value is {}.".format(self.schedule_mode)) - _insert_send_recv( - int(cur_device.split(':')[1]), - int(prev_device.split(':')[1])) + cur_stage = int(cur_device.split(':')[1]) + prev_stage = int(prev_device.split(':')[1]) + + is_forward = self._is_forward_op(op) + is_backward = self._is_backward_op(op) + assert is_forward or is_backward, \ + 'send/recv in pipeline should only be inserted in forward or backward,' \ + 'please check the op_role of op={}'.format(op) + + if is_forward: + assert prev_stage < cur_stage, \ + "In forward, send/recv can only be passed forward, but now " \ + "prev_stage={} great than cur_stage={}, please check op_device of op={}".format(prev_stage, cur_stage, op) + elif is_backward: + assert prev_stage > cur_stage, \ + "In backward, send/recv can only be passed backward, but now " \ + "prev_stage={} less than cur_stage={}, please check op_device of op={}".format(prev_stage, cur_stage, op) + + _insert_send_recv(cur_stage, prev_stage) block._sync_with_cpp() def _insert_loss_scale(self, block): From 101f2f997f215b6f66214c3a46572d8e0a1666d0 Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 13 Aug 2021 13:28:53 +0800 Subject: [PATCH 4/5] unuse mp send/recv if numel cannot divisble by mp_degree --- python/paddle/fluid/optimizer.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 65accc25361fb2..f87fd21e3b177b 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4865,9 +4865,9 @@ def _insert_send_recv(cur_id, prev_id): var_shape[0] = self.micro_batch_size if var_shape[ 0] < 0 else var_shape[0] - numel = np.prod(var.shape) - assert numel % self.mp_degree == 0, \ - "The numel={} must be divisible by mp_degree={}".format(numel, self.mp_degree) + numel = np.prod(var_shape) + use_mp = (self.mp_degree > 1) and ( + numel % self.mp_degree == 0) if 'subprog' in var.name: # For recompute, if the checkpoints var is layer_norm_6.tmp_2 @@ -4906,8 +4906,7 @@ def _insert_send_recv(cur_id, prev_id): extra_index_info['index'] += 1 block._insert_op_without_sync( index=index + extra_index_info['index'], - type='send_v2' - if self.mp_degree == 1 else 'partial_send', + type='send_v2' if not use_mp else 'partial_send', inputs={'X': var}, attrs={ self._op_device_key: prev_dev, @@ -4943,8 +4942,7 @@ def _insert_send_recv(cur_id, prev_id): extra_index_info['index'] += 1 block._insert_op_without_sync( index=index + extra_index_info['index'], - type='recv_v2' - if self.mp_degree == 1 else 'partial_recv', + type='recv_v2' if not use_mp else 'partial_recv', outputs={'Out': [var]}, attrs={ 'out_shape': var_shape, @@ -4959,7 +4957,7 @@ def _insert_send_recv(cur_id, prev_id): 'id': self.mp_rank, }) extra_index_info['index'] += 1 - if self.mp_degree > 1: + if use_mp: block._insert_op_without_sync( index=index + extra_index_info['index'], type='partial_allgather', From 00404fbe3dceb22dd7b5df905e61513937ef4b8d Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 13 Aug 2021 15:45:17 +0800 Subject: [PATCH 5/5] fix recompute check --- python/paddle/fluid/optimizer.py | 43 ++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index f87fd21e3b177b..3cb6d24c86faf2 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4795,6 +4795,25 @@ def _insert_sendrecv_ops_for_boundaries(self, block): device_type = cur_device.split(':')[0] + ':' + def _check_stage(cur_id, prev_id): + # check send/recv stage valid + is_forward = self._is_forward_op(op) + is_backward = self._is_backward_op(op) + assert is_forward or is_backward, \ + 'send/recv in pipeline should only be inserted in forward or backward,' \ + 'please check the op_role of op={}'.format(op) + + if is_forward: + assert prev_id < cur_id, \ + "In forward, send/recv can only be passed forward, but now " \ + "prev_stage={} great than cur_stage={}, please check op_device of op={}".format( + prev_id, cur_id, op) + elif is_backward: + assert prev_id > cur_id, \ + "In backward, send/recv can only be passed backward, but now " \ + "prev_stage={} less than cur_stage={}, please check op_device of op={}".format( + prev_id, cur_id, op) + def _insert_send_recv(cur_id, prev_id): cur_dev = device_type + str(cur_id) prev_dev = device_type + str(prev_id) @@ -4894,6 +4913,8 @@ def _insert_send_recv(cur_id, prev_id): extra_index_info['index'] += 1 return + _check_stage(cur_id, prev_id) + block._insert_op_without_sync( index=index + extra_index_info['index'], type='c_sync_calc_stream', @@ -4978,25 +4999,9 @@ def _insert_send_recv(cur_id, prev_id): "Now only 'F-then-B' and '1F1B' are supported." "The given value is {}.".format(self.schedule_mode)) - cur_stage = int(cur_device.split(':')[1]) - prev_stage = int(prev_device.split(':')[1]) - - is_forward = self._is_forward_op(op) - is_backward = self._is_backward_op(op) - assert is_forward or is_backward, \ - 'send/recv in pipeline should only be inserted in forward or backward,' \ - 'please check the op_role of op={}'.format(op) - - if is_forward: - assert prev_stage < cur_stage, \ - "In forward, send/recv can only be passed forward, but now " \ - "prev_stage={} great than cur_stage={}, please check op_device of op={}".format(prev_stage, cur_stage, op) - elif is_backward: - assert prev_stage > cur_stage, \ - "In backward, send/recv can only be passed backward, but now " \ - "prev_stage={} less than cur_stage={}, please check op_device of op={}".format(prev_stage, cur_stage, op) - - _insert_send_recv(cur_stage, prev_stage) + _insert_send_recv( + int(cur_device.split(':')[1]), + int(prev_device.split(':')[1])) block._sync_with_cpp() def _insert_loss_scale(self, block):