From 0caf4168da1ffddf21f0ea3a7b195c5fbe4d7ef4 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Tue, 14 Sep 2021 09:00:19 +0800
Subject: [PATCH 01/18] add is_distributed to var __str__ method

---
 python/paddle/fluid/framework.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 92afe0fdaff4d8..7041e8f9e4a47b 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1211,6 +1211,8 @@ def _to_readable_code(self):
             var_str = "{name} : {type}.shape{shape}.dtype({dtype}).stop_gradient({stop_gradient})".\
                 format(name=self.name, type=type_str, shape=self.shape, dtype=dtype_str, stop_gradient=self.stop_gradient)
+            if hasattr(self, 'is_distributed'):
+                var_str += ".is_distributed({})".format(self.is_distributed)
         else:
             var_str = "{name} : {type})".\
                 format(name=self.name, type=type_str)

From 89ee7493f2c966458242530117e839979b7dffdd Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Tue, 14 Sep 2021 09:20:25 +0800
Subject: [PATCH 02/18] keep is_distributed info when pp copies vars

---
 python/paddle/fluid/optimizer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index ed351dcbefdbcc..3040ecfcda07ac 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4382,6 +4382,8 @@ def _create_vars(self, block, ori_block):
             else:
                 dest_var = block._clone_variable(source_var, False)
                 dest_var.stop_gradient = source_var.stop_gradient
+                if hasattr(source_var, 'is_distributed'):
+                    dest_var.is_distributed = source_var.is_distributed
             # When use with sharding, allreduce_sum and allreduce_max
             # used for global gradient clip and amp will be added by sharding.
             op_idx += 1

From b0457d79b63c4f6e173328d16c608caebc1b98c0 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Tue, 14 Sep 2021 09:32:06 +0800
Subject: [PATCH 03/18] is_distributed info pass to grad merge var

---
 python/paddle/fluid/optimizer.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 3040ecfcda07ac..2cd4d84adea7b8 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4550,6 +4550,8 @@ def _create_var(self, block, ref_var, name, dtype=None):
             is_data=ref_var.is_data,
             need_check_feed=ref_var.desc.need_check_feed())
         new_var.stop_gradient = ref_var.stop_gradient
+        if hasattr(ref_var, 'is_distributed'):
+            new_var.is_distributed = ref_var.is_distributed
         return new_var

     def _strip_grad_suffix(self, name):
@@ -5211,6 +5213,8 @@ def _insert_accumulate_gradients_with_fuse(self, main_block, fp16,
                 persistable=True,
                 stop_gradient=False)
             real_param = main_block.var(param)
+            if hasattr(real_param, 'is_distributed'):
+                merged_grad_var.is_distributed = real_param.is_distributed
             tmp_size = self._get_var_size(real_grad)
             # two strategies for splitting the grad
             # 1. the current segment's size reach the user defined grad_size_in_MB

From 7b10da89499277fc3afb152e8a6d0275a1122969 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Tue, 14 Sep 2021 11:05:59 +0800
Subject: [PATCH 04/18] prune mp's duplicated var during gradient clip

---
 .../sharding/gradient_clip_helper.py          | 34 +++++++++++++++++--
 .../meta_optimizers/sharding_optimizer.py     |  3 +-
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 3580e85fc89c19..155d55e654a075 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -142,13 +142,43 @@ def prune_gradient_clip(self, block, shard, ring_ids):
         return

     # TODO (JZ-LIANG) revise this for uniform mixed parallelism
-    def sync_global_norm(self, block, ring_ids):
+    def sync_global_norm(self, block, ring_ids, mp_rank):
         """
        prune gradient_clip related ops for params that not belong to cur shard
        prune: square, reduce_sum, elementwise_mul
        keep: sum, sqrt, elementwise_max, elementwise_div
         """
-        # FIXME(wangxi): mp should prune duplicated param_grads
+        removed_op_idx = []
+        removed_tmp_var = []
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                break
+            for input_name in op.input_arg_names:
+                input_var = block.var(input_name)
+                if hasattr(input_var, 'is_distributed') \
+                        and not input_var.is_distributed and mp_rank != 0:
+                    removed_op_idx.append(idx)
+                    removed_tmp_var.extend(op.output_arg_names)
+
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if idx in removed_op_idx:
+                block._remove_op(idx, sync=False)
+                continue
+
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum' and mp_rank != 0:
+                reserved_vars = []
+                for input_name in op.input_arg_names:
+                    if input_name not in removed_tmp_var:
+                        reserved_vars.append(input_name)
+                op.desc.set_input("X", reserved_vars)
+
         for idx, op in reversed(list(enumerate(block.ops))):
             if not self._is_gradient_clip_op(op):
                 continue

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 1f96ab07d60a84..f14f1e06624028 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -435,7 +435,6 @@ def _adapt_amp_clip_without_sharding(self):
         main_block = self._main_program.global_block()
         startup_block = self._startup_program.global_block()

-        # FIXME(wangxi): mp should prune duplicated param_grads when calc
         # amp inf_var & clip global_norm_var
         rings = [self.mp_ring_id, self.pp_ring_id]
@@ -446,7 +445,7 @@
         gradientclip_helper = GradientClipHelper(None)
         gradientclip_helper.sync_global_norm(
-            main_block, [self.mp_ring_id, self.pp_ring_id])
+            main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank)

     def _insert_loss_grad_scale_op(self):
         main_block = self._main_program.global_block()
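Patches 01-04 thread a custom is_distributed attribute from the original variables onto every copy created by the pipeline and gradient-merge passes, so the gradient-clip pass can later tell which gradients are split across model-parallel ranks. A minimal illustrative sketch of that attribute-copying idea; the Var class and variable names below are made up stand-ins, not Paddle's real Variable/Block APIs:

# Illustrative sketch only -- not part of the patch series.
class Var:
    def __init__(self, name, is_distributed=None):
        self.name = name
        self.stop_gradient = False
        if is_distributed is not None:
            self.is_distributed = is_distributed

def clone_var_attr(dest, src):
    # copy the flag only when the source variable actually carries it
    if hasattr(src, 'is_distributed'):
        dest.is_distributed = src.is_distributed

row_parallel_w = Var("fc_0.w_0", is_distributed=True)
pp_copy = Var("fc_0.w_0")          # clone made by the pipeline pass
clone_var_attr(pp_copy, row_parallel_w)
assert pp_copy.is_distributed is True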
From 02117b087f2cdcd1e68dd4acffd3a1981a18a4e0 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Tue, 14 Sep 2021 12:06:29 +0800
Subject: [PATCH 05/18] remove framework modification

---
 python/paddle/fluid/framework.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 7041e8f9e4a47b..92afe0fdaff4d8 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1211,8 +1211,6 @@ def _to_readable_code(self):
             var_str = "{name} : {type}.shape{shape}.dtype({dtype}).stop_gradient({stop_gradient})".\
                 format(name=self.name, type=type_str, shape=self.shape, dtype=dtype_str, stop_gradient=self.stop_gradient)
-            if hasattr(self, 'is_distributed'):
-                var_str += ".is_distributed({})".format(self.is_distributed)
         else:
             var_str = "{name} : {type})".\
                 format(name=self.name, type=type_str)

From 0cca490b0dab7c1f6858bfe6c3af093413566479 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Tue, 14 Sep 2021 16:42:32 +0800
Subject: [PATCH 06/18] extract the copy attr method

---
 python/paddle/fluid/optimizer.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 2cd4d84adea7b8..737b53f6ddcc20 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4382,8 +4382,7 @@ def _create_vars(self, block, ori_block):
             else:
                 dest_var = block._clone_variable(source_var, False)
                 dest_var.stop_gradient = source_var.stop_gradient
-                if hasattr(source_var, 'is_distributed'):
-                    dest_var.is_distributed = source_var.is_distributed
+                self._clone_var_attr(dest_var, source_var)
             # When use with sharding, allreduce_sum and allreduce_max
             # used for global gradient clip and amp will be added by sharding.
             op_idx += 1
@@ -4550,10 +4549,13 @@ def _create_var(self, block, ref_var, name, dtype=None):
             is_data=ref_var.is_data,
             need_check_feed=ref_var.desc.need_check_feed())
         new_var.stop_gradient = ref_var.stop_gradient
-        if hasattr(ref_var, 'is_distributed'):
-            new_var.is_distributed = ref_var.is_distributed
+        self._clone_var_attr(new_var, ref_var)
         return new_var

+    def _clone_var_attr(self, dest, src):
+        if hasattr(src, 'is_distributed'):
+            dest.is_distributed = src.is_distributed
+
     def _strip_grad_suffix(self, name):
         """
         Strip the grad suffix from the given variable name
@@ -5213,8 +5215,7 @@ def _insert_accumulate_gradients_with_fuse(self, main_block, fp16,
                 persistable=True,
                 stop_gradient=False)
             real_param = main_block.var(param)
-            if hasattr(real_param, 'is_distributed'):
-                merged_grad_var.is_distributed = real_param.is_distributed
+            self._clone_var_attr(merged_grad_var, real_param)
             tmp_size = self._get_var_size(real_grad)
             # two strategies for splitting the grad
             # 1. the current segment's size reach the user defined grad_size_in_MB

From f7f73cc133bade7215217c9fc913f52f480f576e Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Tue, 14 Sep 2021 16:47:57 +0800
Subject: [PATCH 07/18] deteremine global clip or not

---
 .../sharding/gradient_clip_helper.py          | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 155d55e654a075..467d6c89acb1db 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -148,6 +148,17 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
        prune: square, reduce_sum, elementwise_mul
        keep: sum, sqrt, elementwise_max, elementwise_div
         """
+        is_clip_grad_by_global_norm = False
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                is_clip_grad_by_global_norm = True
+                break
+        if not is_clip_grad_by_global_norm:
+            # TODO(Yuang Liu): need some extra handles when clip_grad_norm for mp
+            return
+
         removed_op_idx = []
         removed_tmp_var = []
         for idx, op in list(enumerate(block.ops)):
@@ -157,8 +168,8 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                 break
             for input_name in op.input_arg_names:
                 input_var = block.var(input_name)
-                if hasattr(input_var, 'is_distributed') \
-                        and not input_var.is_distributed and mp_rank != 0:
+                if mp_rank != 0 and (not hasattr(input_var, 'is_distributed') or
+                                     input_var.is_distributed):
                     removed_op_idx.append(idx)
                     removed_tmp_var.extend(op.output_arg_names)

From ea117c8ccd86912d4f88328a74f23f3dd6f149ff Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Wed, 15 Sep 2021 09:28:18 +0800
Subject: [PATCH 08/18] add note, bug fix

---
 .../meta_optimizers/sharding/gradient_clip_helper.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 467d6c89acb1db..10317a307c79d8 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -168,8 +168,14 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                 break
             for input_name in op.input_arg_names:
                 input_var = block.var(input_name)
-                if mp_rank != 0 and (not hasattr(input_var, 'is_distributed') or
-                                     input_var.is_distributed):
+                # NOTE: when mp_degree > 1, some vars will be split into each mp rank.
+                # However, there still some vars such as Scale, Bias are not split.
+                # Those not be split vars should only be counted once during grad clip
+                # by global norm. Those vars either doesn't have is_distributed attr
+                # or the is_distributed attr has been set as False.
+                # Therefore, we prune those duplicated vars for grad clip.
+                if mp_rank > 1 and (not (hasattr(input_var, 'is_distributed')
+                                         and input_var.is_distributed)):
                     removed_op_idx.append(idx)
                     removed_tmp_var.extend(op.output_arg_names)
@@ -178,7 +184,6 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                 continue
             if idx in removed_op_idx:
                 block._remove_op(idx, sync=False)
-                continue

         for idx, op in list(enumerate(block.ops)):
             if not self._is_gradient_clip_op(op):
                 continue
@@ -189,6 +194,7 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                     if input_name not in removed_tmp_var:
                         reserved_vars.append(input_name)
                 op.desc.set_input("X", reserved_vars)
+                break

         for idx, op in reversed(list(enumerate(block.ops))):
             if not self._is_gradient_clip_op(op):
                 continue
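Patches 07-08 narrow the pruning to the clip-by-global-norm case and spell out the rule: on every model-parallel rank except rank 0, contributions from variables that are not split across mp ranks (no is_distributed attribute, or attribute set to False) are dropped, so shared parameters such as Scale and Bias enter the global norm exactly once. A small self-contained sketch of the selection rule the series eventually settles on; the names are illustrative, not Paddle APIs:

# Illustrative sketch only -- not part of the patch series.
def keep_for_local_norm(mp_rank, is_distributed):
    # rank 0 keeps everything; other ranks keep only vars split across mp ranks
    return mp_rank == 0 or is_distributed

grads = [("fc_0.w_0@GRAD", True),            # split across mp ranks
         ("layer_norm_0.b_0@GRAD", False)]   # duplicated on every mp rank

kept_rank0 = [name for name, dist in grads if keep_for_local_norm(0, dist)]
kept_rank1 = [name for name, dist in grads if keep_for_local_norm(1, dist)]
assert kept_rank0 == ["fc_0.w_0@GRAD", "layer_norm_0.b_0@GRAD"]
assert kept_rank1 == ["fc_0.w_0@GRAD"]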
From 644cc3e2aed5359e901ee3739c36936a6c5fb81a Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Wed, 15 Sep 2021 09:58:14 +0800
Subject: [PATCH 09/18] update file

---
 .../fleet/meta_optimizers/sharding/gradient_clip_helper.py | 3 +++
 python/paddle/fluid/optimizer.py                           | 6 +++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 10317a307c79d8..5e9e15a22f0c71 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -185,6 +185,9 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
             if idx in removed_op_idx:
                 block._remove_op(idx, sync=False)

+        for var_name in removed_tmp_var:
+            block._remove_var(var_name, sync=False)
+
         for idx, op in list(enumerate(block.ops)):
             if not self._is_gradient_clip_op(op):
                 continue

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 737b53f6ddcc20..709b36ed8e32b2 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4381,7 +4381,6 @@ def _create_vars(self, block, ori_block):
                 persistable=source_var.persistable)
             else:
                 dest_var = block._clone_variable(source_var, False)
-                dest_var.stop_gradient = source_var.stop_gradient
                 self._clone_var_attr(dest_var, source_var)
             # When use with sharding, allreduce_sum and allreduce_max
             # used for global gradient clip and amp will be added by sharding.
             op_idx += 1
@@ -4548,11 +4547,11 @@ def _create_var(self, block, ref_var, name, dtype=None):
             persistable=ref_var.persistable,
             is_data=ref_var.is_data,
             need_check_feed=ref_var.desc.need_check_feed())
-        new_var.stop_gradient = ref_var.stop_gradient
         self._clone_var_attr(new_var, ref_var)
         return new_var

     def _clone_var_attr(self, dest, src):
+        dest.stop_gradient = src.stop_gradient
         if hasattr(src, 'is_distributed'):
             dest.is_distributed = src.is_distributed

@@ -5215,7 +5214,8 @@ def _insert_accumulate_gradients_with_fuse(self, main_block, fp16,
                 persistable=True,
                 stop_gradient=False)
             real_param = main_block.var(param)
-            self._clone_var_attr(merged_grad_var, real_param)
+            if hasattr(real_param, 'is_distributed'):
+                merged_grad_var.is_distributed = real_param.is_distributed
             tmp_size = self._get_var_size(real_grad)
             # two strategies for splitting the grad
             # 1. the current segment's size reach the user defined grad_size_in_MB

From 45daa5740ec63254b4477980123531f814c3d26e Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Wed, 15 Sep 2021 10:52:41 +0800
Subject: [PATCH 10/18] bug fix

---
 .../fleet/meta_optimizers/sharding/gradient_clip_helper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 5e9e15a22f0c71..37c6bdc49f1cd9 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -174,7 +174,7 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                 # by global norm. Those vars either doesn't have is_distributed attr
                 # or the is_distributed attr has been set as False.
                 # Therefore, we prune those duplicated vars for grad clip.
-                if mp_rank > 1 and (not (hasattr(input_var, 'is_distributed')
+                if mp_rank > 0 and (not (hasattr(input_var, 'is_distributed')
                                          and input_var.is_distributed)):
                     removed_op_idx.append(idx)
                     removed_tmp_var.extend(op.output_arg_names)
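Patch 10 is a one-character but important fix: mp ranks are numbered from 0, so with mp_degree == 2 the only non-zero rank is 1 and a "mp_rank > 1" guard never fires, leaving the duplicated vars unpruned on every rank. A quick illustrative check of the boundary condition (plain Python, not Paddle code):

# Illustrative check only -- not part of the patch series.
mp_degree = 2
ranks = list(range(mp_degree))                    # [0, 1]
prune_with_gt_1 = [r for r in ranks if r > 1]     # buggy guard: prunes nowhere
prune_with_gt_0 = [r for r in ranks if r > 0]     # fixed guard: prunes on rank 1
assert prune_with_gt_1 == []
assert prune_with_gt_0 == [1]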
From 86268d8c5cb4864a2e8b4dd0ad8917d9deae8f30 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Wed, 15 Sep 2021 15:59:50 +0800
Subject: [PATCH 11/18] replace list with set

---
 .../sharding/gradient_clip_helper.py          | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 37c6bdc49f1cd9..9236f9f70c331e 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -159,8 +159,8 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
             # TODO(Yuang Liu): need some extra handles when clip_grad_norm for mp
             return

-        removed_op_idx = []
-        removed_tmp_var = []
+        removed_op_idx = set()
+        removed_tmp_var = set()
         for idx, op in list(enumerate(block.ops)):
             if not self._is_gradient_clip_op(op):
                 continue
@@ -174,10 +174,11 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                 # by global norm. Those vars either doesn't have is_distributed attr
                 # or the is_distributed attr has been set as False.
                 # Therefore, we prune those duplicated vars for grad clip.
-                if mp_rank > 0 and (not (hasattr(input_var, 'is_distributed')
-                                         and input_var.is_distributed)):
-                    removed_op_idx.append(idx)
-                    removed_tmp_var.extend(op.output_arg_names)
+                if mp_rank >= 1 and (not (hasattr(input_var, 'is_distributed')
+                                          and input_var.is_distributed)):
+                    removed_op_idx.add(idx)
+                    for output_name in op.output_arg_names:
+                        removed_tmp_var.add(output_name)

         for idx, op in reversed(list(enumerate(block.ops))):
             if not self._is_gradient_clip_op(op):
                 continue
@@ -191,7 +192,7 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
         for idx, op in list(enumerate(block.ops)):
             if not self._is_gradient_clip_op(op):
                 continue
-            if op.type == 'sum' and mp_rank != 0:
+            if op.type == 'sum' and mp_rank >= 1:
                 reserved_vars = []
                 for input_name in op.input_arg_names:
                     if input_name not in removed_tmp_var:
                         reserved_vars.append(input_name)
                 op.desc.set_input("X", reserved_vars)

From c5042ad8b9abd9d9e924122bb1d7ed58b4939c18 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Wed, 15 Sep 2021 16:35:10 +0800
Subject: [PATCH 12/18] fix a potential bug

---
 .../sharding/gradient_clip_helper.py          | 21 ++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 9236f9f70c331e..b8d3b8541164dd 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -197,7 +197,26 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                 for input_name in op.input_arg_names:
                     if input_name not in removed_tmp_var:
                         reserved_vars.append(input_name)
-                op.desc.set_input("X", reserved_vars)
+                if len(reserved_vars) > 0:
+                    op.desc.set_input("X", reserved_vars)
+                else:
+                    # If all input of sum op should be removed, then remove the sum op.
+                    # And set the output's value of sum to 0.
+                    sum_rst_var = block.var(op.output_arg_names[0])
+                    namescope = op.attr("op_namescope")
+                    block._remove_op(idx, sync=False)
+                    fill_constant_op = block._insert_op_without_sync(
+                        idx,
+                        type='fill_constant',
+                        inputs={},
+                        outputs={'Out': sum_rst_var},
+                        attrs={
+                            'shape': sum_rst_var.shape,
+                            'dtype': sum_rst_var.dtype,
+                            'value': 0.0,
+                            OP_ROLE_KEY: op.attr(OP_ROLE_KEY)
+                        })
+                    fill_constant_op._set_attr('op_namescope', namescope)
                 break

From a0457217ef3211a0b222ad79afd4d1c18be86e20 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 16 Sep 2021 09:24:50 +0800
Subject: [PATCH 13/18] fix no allreduce prob after pruning sum op

---
 .../sharding/gradient_clip_helper.py          | 38 +++++++++++--------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index b8d3b8541164dd..d3d3a6892c2d42 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -214,9 +214,10 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                             'shape': sum_rst_var.shape,
                             'dtype': sum_rst_var.dtype,
                             'value': 0.0,
-                            OP_ROLE_KEY: op.attr(OP_ROLE_KEY)
+                            OP_ROLE_KEY: OpRole.Optimize
                         })
                     fill_constant_op._set_attr('op_namescope', namescope)
+                    self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
                 break

         for idx, op in reversed(list(enumerate(block.ops))):
@@ -225,19 +226,24 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
             if op.type == "sum":
                 sum_res = op.desc.output_arg_names()[0]
-                for ring_id in ring_ids:
-                    if ring_id == -1: continue
-
-                    idx = idx + 1
-                    block._insert_op_without_sync(
-                        idx,
-                        type='c_allreduce_sum',
-                        inputs={'X': sum_res},
-                        outputs={'Out': sum_res},
-                        attrs={
-                            'ring_id': ring_id,
-                            'op_namescope': "/gradient_clip_model_parallelism",
-                            'use_calc_stream': True,
-                            OP_ROLE_KEY: OpRole.Optimize,
-                        })
+                self._insert_allreduce(block, ring_ids, idx, sum_res)
         return
+
+    @staticmethod
+    def _insert_allreduce(block, ring_ids, idx, var):
+        for ring_id in ring_ids:
+            if ring_id == -1:
+                continue
+
+            idx = idx + 1
+            block._insert_op_without_sync(
+                idx,
+                type='c_allreduce_sum',
+                inputs={'X': var},
+                outputs={'Out': var},
+                attrs={
+                    'ring_id': ring_id,
+                    'op_namescope': "/gradient_clip_model_parallelism",
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Optimize,
+                })

From 383fb5f70a262a239f61d68232dc3bb62ef0580b Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 16 Sep 2021 09:26:45 +0800
Subject: [PATCH 14/18] early return

---
 .../fleet/meta_optimizers/sharding/gradient_clip_helper.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index d3d3a6892c2d42..4a17a80fb8c16c 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -218,6 +218,7 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                         })
                     fill_constant_op._set_attr('op_namescope', namescope)
                     self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
+                    return
                 break
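Patches 12-14 cover the corner case where every input of a rank's local sum op has been pruned: the sum op is replaced by a fill_constant that writes 0 into its output, and a c_allreduce_sum is still inserted afterwards. The allreduce is a collective, so a rank that skipped it would leave the other ranks in the ring blocked. An illustrative, framework-free sketch of why the zero contribution is still needed; allreduce_sum below is a stand-in for the real c_allreduce_sum op, not Paddle code:

# Illustrative sketch only -- not part of the patch series.
def allreduce_sum(contributions):
    # every rank in the ring must show up in `contributions`,
    # otherwise the collective never completes
    total = sum(contributions.values())
    return {rank: total for rank in contributions}

local_sq_sum = {0: 13.5,   # mp rank 0 sums all gradient squares
                1: 0.0}    # mp rank 1 had nothing left after pruning: it still
                           # contributes a fill_constant zero instead of skipping
global_sq_sum = allreduce_sum(local_sq_sum)
assert global_sq_sum[0] == global_sq_sum[1] == 13.5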
From 8b2fabcf55336f2de58fda62ed0ea6e417e4efe5 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 16 Sep 2021 09:29:21 +0800
Subject: [PATCH 15/18] add comment

---
 .../fleet/meta_optimizers/sharding/gradient_clip_helper.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 4a17a80fb8c16c..0880e44516d82e 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -217,6 +217,7 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                         })
                     fill_constant_op._set_attr('op_namescope', namescope)
+                    # insert redundant allreduce to prevent hang
                     self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
                     return
                 break

From 75882840e7b337d1328c38651f32301317ddd9c7 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 16 Sep 2021 10:42:49 +0800
Subject: [PATCH 16/18] fix ci

---
 .../fleet/meta_optimizers/sharding/gradient_clip_helper.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 0880e44516d82e..a73300dd276248 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -219,7 +219,6 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
                     fill_constant_op._set_attr('op_namescope', namescope)
                     # insert redundant allreduce to prevent hang
                     self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
-                    return
                 break

From d9b6d50ed63ea3395d609074804262e52fb7788e Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 16 Sep 2021 14:57:19 +0800
Subject: [PATCH 17/18] restruct the logic

---
 .../sharding/gradient_clip_helper.py          | 66 +++++++++----------
 1 file changed, 30 insertions(+), 36 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index a73300dd276248..5d28c2d5cebd92 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -192,44 +192,38 @@ def sync_global_norm(self, block, ring_ids, mp_rank):
         for idx, op in list(enumerate(block.ops)):
             if not self._is_gradient_clip_op(op):
                 continue
-            if op.type == 'sum' and mp_rank >= 1:
-                reserved_vars = []
-                for input_name in op.input_arg_names:
-                    if input_name not in removed_tmp_var:
-                        reserved_vars.append(input_name)
-                if len(reserved_vars) > 0:
-                    op.desc.set_input("X", reserved_vars)
-                else:
-                    # If all input of sum op should be removed, then remove the sum op.
-                    # And set the output's value of sum to 0.
-                    sum_rst_var = block.var(op.output_arg_names[0])
-                    namescope = op.attr("op_namescope")
-                    block._remove_op(idx, sync=False)
-                    fill_constant_op = block._insert_op_without_sync(
-                        idx,
-                        type='fill_constant',
-                        inputs={},
-                        outputs={'Out': sum_rst_var},
-                        attrs={
-                            'shape': sum_rst_var.shape,
-                            'dtype': sum_rst_var.dtype,
-                            'value': 0.0,
-                            OP_ROLE_KEY: OpRole.Optimize
-                        })
-                    fill_constant_op._set_attr('op_namescope', namescope)
-                    # insert redundant allreduce to prevent hang
-                    self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
+            if op.type == 'sum':
+                # If mp_rank == 0, no extra handles, just allreduce
+                # If mp_rank >= 1, some extra handles is needed
+                sum_rst_var = block.var(op.output_arg_names[0])
+                if mp_rank >= 1:
+                    reserved_vars = []
+                    for input_name in op.input_arg_names:
+                        if input_name not in removed_tmp_var:
+                            reserved_vars.append(input_name)
+
+                    if len(reserved_vars) > 0:
+                        op.desc.set_input("X", reserved_vars)
+                    else:
+                        # If all input of sum op should be removed, then remove the sum op.
+                        # And set the output's value of sum to 0.
+                        namescope = op.attr("op_namescope")
+                        block._remove_op(idx, sync=False)
+                        fill_constant_op = block._insert_op_without_sync(
+                            idx,
+                            type='fill_constant',
+                            inputs={},
+                            outputs={'Out': sum_rst_var},
+                            attrs={
+                                'shape': sum_rst_var.shape,
+                                'dtype': sum_rst_var.dtype,
+                                'value': 0.0,
+                                OP_ROLE_KEY: OpRole.Optimize
+                            })
+                        fill_constant_op._set_attr('op_namescope', namescope)
+                self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
                 break

-        for idx, op in reversed(list(enumerate(block.ops))):
-            if not self._is_gradient_clip_op(op):
-                continue
-
-            if op.type == "sum":
-                sum_res = op.desc.output_arg_names()[0]
-                self._insert_allreduce(block, ring_ids, idx, sum_res)
-        return
-
     @staticmethod
     def _insert_allreduce(block, ring_ids, idx, var):
         for ring_id in ring_ids:

From 7dd5b95ab1739b1185ff034357d947619b73c9e9 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 16 Sep 2021 15:31:44 +0800
Subject: [PATCH 18/18] update ut

---
 .../test_fleet_sharding_meta_optimizer.py | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index 1dd368f0848c1d..6b0a7b79c232cc 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -658,6 +658,33 @@ def test_hybrid_with_mp_pp_amp_gclip(self):
             'c_gen_nccl_id', 'c_comm_init'
         ])

+        self.assertEqual(main_prog_op_types, [
+            'partial_recv', 'partial_allgather', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
+            'partial_send', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
+            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
+            'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
+            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
+            'elementwise_div', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
+            'momentum', 'momentum', 'momentum', 'momentum', 'momentum',
+            'momentum', 'momentum'
+        ])
+
         # pp + mp, partial send recv
         self.assertIn('partial_recv', main_prog_op_types)
         self.assertIn('partial_allgather', main_prog_op_types)