From 14b07bae52be7cfd4a041aeb6f1a48b9af693528 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 29 Jan 2021 17:36:20 +0800 Subject: [PATCH 01/24] Recompute fixed bug BOOL VarType --- .../fleet/meta_optimizers/sharding/utils.py | 137 ++++++++++-------- .../meta_optimizers/sharding_optimizer.py | 64 ++++++-- python/paddle/fluid/backward.py | 36 +++-- 3 files changed, 155 insertions(+), 82 deletions(-) mode change 100644 => 100755 python/paddle/fluid/backward.py diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index ad1cd4f60826bb..eb5767ec4d3436 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -28,21 +28,24 @@ def check_broadcast(block): if the broadcasted var has a fill_constant op, the fill_constant op should stay forward before the broadcast op, and before a sync_calc op. Otherwise, raise error. + + should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron) """ broadcast_vars = {} for idx, op in enumerate(block.ops): if op.type == "c_broadcast": - var_name = op.desc.input_arg_names()[0] - if "@BroadCast" in var_name: - if var_name in broadcast_vars: - raise ValueError("var_name areadly exist: {}" - "the old pos is {}, the new pos is {}". - format(var_name, broadcast_vars[var_name][ - "broadcast_pos"], idx)) - broadcast_vars[var_name] = { - "fill_constant_pos": -1, - "broadcast_pos": idx, - } + if op.all_attrs()["use_calc_stream"] == False: + var_name = op.desc.input_arg_names()[0] + if "@BroadCast" in var_name: + if var_name in broadcast_vars: + raise ValueError("var_name areadly exist: {}" + "the old pos is {}, the new pos is {}". + format(var_name, broadcast_vars[ + var_name]["broadcast_pos"], idx)) + broadcast_vars[var_name] = { + "fill_constant_pos": -1, + "broadcast_pos": idx, + } for idx, op in enumerate(block.ops): if op.type == "fill_constant": @@ -61,14 +64,15 @@ def check_broadcast(block): last_sync_calc_op_idx = idx continue if op.type == "c_broadcast": - var_name = op.desc.input_arg_names()[0] - if "@BroadCast" in var_name: - if broadcast_vars[var_name]["fill_constant_pos"] != -1: - assert (last_sync_calc_op_idx != -1) - assert (broadcast_vars[var_name]["fill_constant_pos"] < - last_sync_calc_op_idx) - assert (last_sync_calc_op_idx < idx) - continue + if op.all_attrs()["use_calc_stream"] == False: + var_name = op.desc.input_arg_names()[0] + if "@BroadCast" in var_name: + if broadcast_vars[var_name]["fill_constant_pos"] != -1: + assert (last_sync_calc_op_idx != -1) + assert (broadcast_vars[var_name]["fill_constant_pos"] < + last_sync_calc_op_idx) + assert (last_sync_calc_op_idx < idx) + continue for input_name in op.desc.input_arg_names(): if input_name in broadcast_vars: assert (broadcast_vars[input_name]["broadcast_pos"] != -1) @@ -78,7 +82,7 @@ def check_broadcast(block): return -def check_allreduce_sum(block, shard, dp_ring_id=-1): +def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): """ the op order should be: grad: @@ -89,32 +93,36 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): - 4: allreuce_sum_dp (dp_grads) - 5: sync_comm (dp_grads) - 6: op that use Var (dp_grads & sum) + + should ignore and skip allreduce_op of inner_parallelism (e.g. 
Megatron) """ vars_status = {} dp_grads_status = {} idx_last_grad_allreduce = -1 idx_amp_allreduce = -1 idx_gradient_clip_allreduce = -1 + for idx, op in enumerate(block.ops): if op.type == "c_allreduce_sum": - ring_id = op.desc.attr("ring_id") - var_name = op.desc.input_arg_names()[0] - param = var_name.split("@")[0] + if op.all_attrs()["use_calc_stream"] == False: + ring_id = op.desc.attr("ring_id") + var_name = op.desc.input_arg_names()[0] + param = var_name.split("@")[0] - assert 'sum' in var_name or ("@GRAD" in var_name) - if 'sum' in var_name or (not shard.has_param(param)): - vars_status[var_name] = -1 - else: - dp_grads_status[var_name] = -1 + assert 'sum' in var_name or ("@GRAD" in var_name) + if 'sum' in var_name or (not shard.has_param(param)): + vars_status[var_name] = -1 + else: + dp_grads_status[var_name] = -1 - if ring_id != 0: - assert shard.has_param(param) - assert ring_id == dp_ring_id + if ring_id != sharding_ring_id: + assert shard.has_param(param) + assert ring_id == dp_ring_id - if "sum" in var_name: - idx_amp_allreduce = idx - elif "@GRAD": - idx_last_grad_allreduce = idx + if "sum" in var_name: + idx_amp_allreduce = idx + elif "@GRAD": + idx_last_grad_allreduce = idx if op.type == "c_allreduce_max": idx_gradient_clip_allreduce = idx @@ -130,36 +138,38 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): dp_grads_status[var_name] = 1 elif op.type == "c_allreduce_sum": - var_name = op.desc.input_arg_names()[0] - ring_id = op.desc.attr("ring_id") - if ring_id == 0: - if var_name in vars_status: - _status = vars_status[var_name] - else: - _status = dp_grads_status[var_name] - if _status == -1: - raise ValueError("{} is not generated, but you are" - "trying to all-reduce it".format(var_name)) - if _status == 0: - raise ValueError("There should be a sync_calc op " - "after generate Var: {} and before the" - "c_allreduce_sum op".format(var_name)) - assert (_status == 1) - if var_name in vars_status: - vars_status[var_name] = 2 + if op.all_attrs()["use_calc_stream"] == False: + var_name = op.desc.input_arg_names()[0] + ring_id = op.desc.attr("ring_id") + if ring_id == sharding_ring_id: + if var_name in vars_status: + _status = vars_status[var_name] + else: + _status = dp_grads_status[var_name] + if _status == -1: + raise ValueError("{} is not generated, but you are" + "trying to all-reduce it".format( + var_name)) + if _status == 0: + raise ValueError("There should be a sync_calc op " + "after generate Var: {} and before the" + "c_allreduce_sum op".format(var_name)) + assert (_status == 1) + if var_name in vars_status: + vars_status[var_name] = 2 + else: + dp_grads_status[var_name] = 2 else: - dp_grads_status[var_name] = 2 - else: - assert ring_id == dp_ring_id - param = var_name.split("@")[0] - assert shard.has_param(param) - assert dp_grads_status[var_name] == 3 - dp_grads_status[var_name] = 4 + assert ring_id == dp_ring_id + param = var_name.split("@")[0] + assert shard.has_param(param) + assert dp_grads_status[var_name] == 3 + dp_grads_status[var_name] = 4 elif op.type == "c_sync_comm_stream": var_name = op.desc.input_arg_names()[0] ring_id = op.desc.attr("ring_id") - if ring_id == 0: + if ring_id == sharding_ring_id: for var_name in op.desc.input_arg_names(): if var_name in vars_status: assert vars_status[var_name] == 2 @@ -428,7 +438,7 @@ def comm_analyse(main_program): count)) -def add_sync_comm(program, dist_strategy): +def add_sync_comm(program, nccl_ids): """ When clone a test prog by clone from the sharding main prog, part of the sync_comm op maybe be pruned 
by mistake, this function @@ -438,6 +448,9 @@ def add_sync_comm(program, dist_strategy): #NOTE (liangjianzhong): only support one comm stream by now, use more than one # comm streams will cause error. should be revise in future. + assert isinstance( + nccl_ids, list + ), "the second argument of this function should be a list of nccl_ids" block = program.global_block() not_sync_vars = set([]) for op in block.ops: @@ -448,7 +461,7 @@ def add_sync_comm(program, dist_strategy): for input_name in op.desc.input_arg_names(): not_sync_vars.remove(input_name) if not_sync_vars: - for nccl_id in range(dist_strategy.nccl_comm_num): + for nccl_id in nccl_ids: block.append_op( type='c_sync_comm_stream', inputs={'X': list(not_sync_vars)}, diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index a7f704361d31af..96618df9e464bb 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -51,6 +51,10 @@ def __init__(self, optimizer): self._reduced_grads_to_param = {} self._shard = Shard() + # use sharding as outer parallelism (e.g. inner:Megatron & outer sharding) + self._as_outer_parallelism = True + self._inner_parallelism_size = 4 + def _can_apply(self): if not self.role_maker._is_collective: return False @@ -105,8 +109,10 @@ def minimize_impl(self, startup_block._sync_with_cpp() # step4: insert reduce_sum for grad - insert_scale_loss_grad_ops( - main_block, scale=1.0 / self.role_maker._worker_num()) + grad_scale_coeff = self.role_maker._worker_num() + if self._as_outer_parallelism: + grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size + insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff) main_block._sync_with_cpp() # step5: remove unneeded ops and vars from block @@ -115,7 +121,8 @@ def minimize_impl(self, # check op dependecy check_broadcast(main_block) - check_allreduce_sum(main_block, self._shard, self.dp_ring_id) + check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, + self.dp_ring_id) self._wait() return optimize_ops, params_grads @@ -459,6 +466,7 @@ def _prune_startup_program(self, block): def _init_comm(self): if self.hybrid_dp: + assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism" self.sharding_group_size = self.user_defined_strategy.sharding_configs[ "sharding_group_size"] self.sharding_ring_id = 0 @@ -488,17 +496,55 @@ def _init_comm(self): logging.info("Using Sharing&DP mode !") else: - self.sharding_ring_id = 0 - self.sharding_rank = self.global_rank - self.sharding_group_size = self.role_maker._worker_num() - self.sharding_group_endpoints = self.endpoints + if self._as_outer_parallelism: + self.sharding_ring_id = 1 + assert self.global_word_size > self._inner_parallelism_size, \ + "global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size) + assert self.global_word_size % self._inner_parallelism_size == 0, \ + "global_word_size: {} should be divisible to the inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size) + self.sharding_rank = self.global_rank // self._inner_parallelism_size + self.sharding_group_size = self.role_maker._worker_num( + ) // self._inner_parallelism_size + _offset = self.global_rank % self._inner_parallelism_size + self.sharding_group_endpoints = [ + ep for idx, 
ep in enumerate(self.endpoints) + if idx % self._inner_parallelism_size == _offset + ] + logging.info("Using Sharing as Outer parallelism mode !") + + print( + "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer" + ) + partition_idx = self.global_rank // self._inner_parallelism_size + magetron_endpoints = self.endpoints[ + partition_idx * self._inner_parallelism_size:partition_idx * + self._inner_parallelism_size + self._inner_parallelism_size] + magetron_rank = self.global_rank % self._inner_parallelism_size + + self._collective_helper._init_communicator( + program=self._startup_program, + current_endpoint=self.current_endpoint, + endpoints=magetron_endpoints, + rank=magetron_rank, + ring_id=0, + wait_port=True) + logging.info("megatron group size: {}".format( + self._inner_parallelism_size)) + logging.info("megatron rank: {}".format(magetron_rank)) + logging.info("megatron endpoints: {}".format( + magetron_endpoints)) + else: + self.sharding_ring_id = 0 + self.sharding_rank = self.global_rank + self.sharding_group_size = self.role_maker._worker_num() + self.sharding_group_endpoints = self.endpoints + logging.info("Using Sharing alone mode !") + self.dp_ring_id = -1 self.dp_rank = -1 self.dp_group_size = None self.dp_group_endpoints = None - logging.info("Using Sharing alone mode !") - logging.info("global word size: {}".format(self.global_word_size)) logging.info("global rank: {}".format(self.global_rank)) logging.info("sharding group_size: {}".format(self.sharding_group_size)) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py old mode 100644 new mode 100755 index 33e2e387a82758..0b4fa1469d77e3 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -115,7 +115,7 @@ def is_amp_cast(op): updated_min_idx = min_idx while idx_ > pre_segment_end_idx: if is_amp_cast(self.ops[idx_]): - _logger.debug("found amp-cast op: {}, : {}".format(self.ops[ + _logger.info("found amp-cast op: {}, : {}".format(self.ops[ idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[ 0])) updated_min_idx = idx_ @@ -155,7 +155,7 @@ def sort_checkpoints(self, checkpoints_name): sorted_checkpoints = [] for name in checkpoints_name: if name not in self.var_op_deps: - _logger.debug( + _logger.info( "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program." 
% name) elif self.var_op_deps[name]["var_as_output_ops"] == []: @@ -784,7 +784,6 @@ def _append_backward_ops_with_checkpoints_( start_idx = 0 pre_segment_end_idx = -1 while True: - _logger.debug("FW op range[0] - [{}]".format(len(ops))) if start_idx >= len(checkpoints_name) - 1: break # min_idx: checkpoint_1' s input op @@ -797,6 +796,9 @@ def _append_backward_ops_with_checkpoints_( min_idx = program_stat._update_segment_start( min_idx, pre_segment_end_idx) segments.append([min_idx, max_idx + 1]) + else: + _logger.info("Could not recompute op range [{}] - [{}] ".format( + min_idx, max_idx + 1)) start_idx += 1 @@ -806,15 +808,15 @@ def _append_backward_ops_with_checkpoints_( recompute_segments = segments for i, (idx1, idx2) in enumerate(recompute_segments): - _logger.debug("recompute segment[{}]".format(i)) - _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( + _logger.info("recompute segment[{}]".format(i)) + _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( ), ops[idx1].desc.input_arg_names())) - _logger.debug("segment end op: [{}]: [{}]".format(ops[ + _logger.info("segment end op: [{}]: [{}]".format(ops[ idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) - _logger.debug("recompute segment[{}]".format(i)) - _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( + _logger.info("recompute segment[{}]".format(i)) + _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( ), ops[idx1].desc.input_arg_names())) - _logger.debug("segment end op: [{}]: [{}]".format(ops[ + _logger.info("segment end op: [{}]: [{}]".format(ops[ idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) # 2) go through all forward ops and induct all variables that will be hold in memory @@ -825,9 +827,9 @@ def _append_backward_ops_with_checkpoints_( program_stat.get_out_of_subgraph_vars(segment[0], segment[1])) cross_vars = set(vars_should_be_hold) - set(checkpoints_name) - _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ + _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ len(cross_vars), cross_vars)) - _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ + _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ len(cross_vars), cross_vars)) # b. output of seed op should be kept in memory @@ -888,6 +890,18 @@ def _append_backward_ops_with_checkpoints_( continue if name not in var_name_dict: var_name_dict[name] = name + var_suffix + + # we should create the rename var in subprog, otherwise its VarType will be BOOL + block.create_var( + name=var_name_dict[name], + shape=block.program.global_block().var(name).shape, + dtype=block.program.global_block().var(name).dtype, + type=block.program.global_block().var(name).type, + persistable=block.program.global_block().var( + name).persistable, + stop_gradient=block.program.global_block().var(name) + .stop_gradient) + # 3.a. 
add ops in current recompute_segment as forward recomputation ops buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block, vars_in_memory) From e6b489d5f5edf98f41f69f131c713e93fe0d43a2 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Sat, 30 Jan 2021 14:02:40 +0800 Subject: [PATCH 02/24] Sharding support Megatron --- .../framework/distributed_strategy.proto | 2 + .../meta_optimizers/sharding_optimizer.py | 52 +++++++++++-------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index b36793507f54bf..e735da6501b25a 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -32,6 +32,8 @@ message ShardingConfig { optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; optional bool hybrid_dp = 2 [ default = false ]; optional int32 sharding_group_size = 3 [ default = 8 ]; + optional bool as_outer_parallelism = 4 [ default = false ]; + optional int32 inner_parallelism_size = 5 [ default = 8 ]; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 96618df9e464bb..07e2dd941f174e 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -39,6 +39,7 @@ def __init__(self, optimizer): "AMPOptimizer", "LarsOptimizer", "LambOptimizer", + "ModelParallelOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self._main_program = None @@ -52,8 +53,8 @@ def __init__(self, optimizer): self._shard = Shard() # use sharding as outer parallelism (e.g. 
inner:Megatron & outer sharding) - self._as_outer_parallelism = True - self._inner_parallelism_size = 4 + self._as_outer_parallelism = False + self._inner_parallelism_size = None def _can_apply(self): if not self.role_maker._is_collective: @@ -83,6 +84,11 @@ def minimize_impl(self, "fuse_broadcast_MB"] self.hybrid_dp = self.user_defined_strategy.sharding_configs[ "hybrid_dp"] + self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[ + "as_outer_parallelism"] + self._inner_parallelism_size = int( + self.user_defined_strategy.sharding_configs[ + "inner_parallelism_size"]) if self.inner_opt is None: raise ValueError( @@ -512,27 +518,27 @@ def _init_comm(self): ] logging.info("Using Sharing as Outer parallelism mode !") - print( - "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer" - ) - partition_idx = self.global_rank // self._inner_parallelism_size - magetron_endpoints = self.endpoints[ - partition_idx * self._inner_parallelism_size:partition_idx * - self._inner_parallelism_size + self._inner_parallelism_size] - magetron_rank = self.global_rank % self._inner_parallelism_size - - self._collective_helper._init_communicator( - program=self._startup_program, - current_endpoint=self.current_endpoint, - endpoints=magetron_endpoints, - rank=magetron_rank, - ring_id=0, - wait_port=True) - logging.info("megatron group size: {}".format( - self._inner_parallelism_size)) - logging.info("megatron rank: {}".format(magetron_rank)) - logging.info("megatron endpoints: {}".format( - magetron_endpoints)) + # print( + # "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer" + # ) + # partition_idx = self.global_rank // self._inner_parallelism_size + # magetron_endpoints = self.endpoints[ + # partition_idx * self._inner_parallelism_size:partition_idx * + # self._inner_parallelism_size + self._inner_parallelism_size] + # magetron_rank = self.global_rank % self._inner_parallelism_size + + # self._collective_helper._init_communicator( + # program=self._startup_program, + # current_endpoint=self.current_endpoint, + # endpoints=magetron_endpoints, + # rank=magetron_rank, + # ring_id=0, + # wait_port=True) + # logging.info("megatron group size: {}".format( + # self._inner_parallelism_size)) + # logging.info("megatron rank: {}".format(magetron_rank)) + # logging.info("megatron endpoints: {}".format( + # magetron_endpoints)) else: self.sharding_ring_id = 0 self.sharding_rank = self.global_rank From c8b0f92e97f6f0f887713efcde7f4d011c055839 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 3 Feb 2021 21:52:43 +0800 Subject: [PATCH 03/24] sharding-megatron support amp, sharidng dp init broadcast --- .../fleet/meta_optimizers/amp_optimizer.py | 3 +- .../meta_optimizers/sharding/fp16_helper.py | 38 ++++++---- .../meta_optimizers/sharding_optimizer.py | 73 ++++++++++++++++++- 3 files changed, 96 insertions(+), 18 deletions(-) mode change 100644 => 100755 python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py mode change 100644 => 100755 python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py old mode 100644 new mode 100755 index dba3c944f70ab8..858a28e6773f11 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -57,8 +57,9 @@ def _init_wrapped_opt(self): # add 
is_distributed to optimize amp, overlap communication and # computation by split the check_finite_and_unscale op. is_distributed = self.role_maker._worker_num() > 1 - if self.user_defined_strategy.sharding: + if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel: # FIXME(wangxi). sharding failed when split check_finite_and_unscale + # FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior is_distributed = False self.wrapped_opt._set_distributed(is_distributed) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py old mode 100644 new mode 100755 index 03b36262a4fb1e..a9f1327cb19a01 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -73,7 +73,7 @@ def remove_cast_op(block, params, segment, offset): @staticmethod def prune_fp16(block, shard, reduced_grads_to_param, ring_id): """ - 1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard + 1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard 2. revise amp inifine grad checking for sharding """ # remove cast @@ -103,6 +103,7 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): op._rename_input(inf_var_name, inf_var_name + "@sharding") if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: reversed_x = [] + reversed_x_paramname = [] for input_name in op.desc.input('X'): param_name = input_name.strip("@GRAD") if param_name not in shard.global_params: @@ -111,12 +112,19 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): "be grads, but {} is not a grad".format(input_name)) if shard.has_param(param_name): reversed_x.append(input_name) + reversed_x_paramname.append(param_name) op.desc.set_input('X', reversed_x) op.desc.set_output('Out', reversed_x) + + # the grad checking should take the all and only param in the current shard + to_check_param = set(reversed_x_paramname) + should_check_param = set(shard.global_params).intersection(set([param for param, worker_idx in shard.global_param2device.items() if worker_idx == shard.worker_idx])) + assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(should_check_param - to_check_param, to_check_param - should_check_param) + if update_loss_scaling_op_idx == -1: return inf_var = block.var(inf_var_name) - inf_var_fp32 = block.create_var( + inf_var_int32 = block.create_var( name=inf_var_name + "@cast_int32", shape=inf_var.shape, dtype=core.VarDesc.VarType.INT32) @@ -128,32 +136,34 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): update_loss_scaling_op_idx, type='cast', inputs={'X': inf_var}, - outputs={'Out': inf_var_fp32}, + outputs={'Out': inf_var_int32}, attrs={ "in_dtype": inf_var.dtype, - "out_dtype": inf_var_fp32.dtype, + "out_dtype": inf_var_int32.dtype, OP_ROLE_KEY: OpRole.Optimize }) - insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, - [inf_var_fp32]) + # this allreduce communication should not overlap with calc + # insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, + # [inf_var_int32]) block._insert_op_without_sync( - update_loss_scaling_op_idx + 2, + update_loss_scaling_op_idx + 1, type='c_allreduce_max', - inputs={'X': inf_var_fp32}, - outputs={'Out': inf_var_fp32}, + inputs={'X': inf_var_int32}, + outputs={'Out': inf_var_int32}, 
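+            # found_inf is a BOOL tensor, so it is cast to INT32 above before
+            # the collective: the c_allreduce kernels do not reduce BOOL, and
+            # allreduce_max of the int32 flag yields the global found_inf.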
attrs={'ring_id': ring_id, + 'use_calc_stream': True, OP_ROLE_KEY: OpRole.Optimize}) - comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, - ring_id, [inf_var_fp32]) + # comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, + # ring_id, [inf_var_int32]) block._insert_op_without_sync( - update_loss_scaling_op_idx + 3 + comm_op_num, + update_loss_scaling_op_idx + 2, type='cast', - inputs={'X': inf_var_fp32}, + inputs={'X': inf_var_int32}, outputs={'Out': inf_var_sharding}, attrs={ - "in_dtype": inf_var_fp32.dtype, + "in_dtype": inf_var_int32.dtype, "out_dtype": inf_var_sharding.dtype, OP_ROLE_KEY: OpRole.Optimize }) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 07e2dd941f174e..b45ab35a635564 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -124,6 +124,8 @@ def minimize_impl(self, # step5: remove unneeded ops and vars from block self._prune_main_program(main_block) self._prune_startup_program(startup_block) + if self.hybrid_dp: + self._initialization_broadcast(startup_program) # check op dependecy check_broadcast(main_block) @@ -147,6 +149,14 @@ def _set_up(self, params_grads): self._startup_program, self.current_endpoint, self.sharding_group_endpoints, self.sharding_rank, self.sharding_ring_id, True) + + # inner & outer model parallelism + if self._as_outer_parallelism: + self._collective_helper._init_communicator( + self._startup_program, self.current_endpoint, + self.mp_group_endpoints, self.mp_rank, + self.mp_group_id, True) + # dp if self.hybrid_dp: self._collective_helper._init_communicator( @@ -247,8 +257,14 @@ def _prune_main_program(self, block): """ weightdecay_helper = WeightDecayHelper() weightdecay_helper.prune_weight_decay(block, self._shard) + # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism + # group. 
and each Data Parallelism group should have its own sync of FoundInfinite + FoundInfinite_ring_id = self.sharding_ring_id + if self._as_outer_parallelism: + FoundInfinite_ring_id = self.mp_group_id FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, - self.sharding_ring_id) + FoundInfinite_ring_id) + gradientclip_helper = GradientClipHelper(self.sharding_ring_id) gradientclip_helper.prune_gradient_clip(block, self._shard) @@ -500,6 +516,13 @@ def _init_comm(self): self.sharding_group_size, self.dp_group_size) + # sharding parallelism is the only model parallelism in the current setting + self.mp_group_id = self.sharding_ring_id + self.mp_rank = self.sharding_rank + self.mp_group_size = self.sharding_group_size + self.mp_group_endpoints = self.sharding_group_endpoints[:] + + logging.info("Using Sharing&DP mode !") else: if self._as_outer_parallelism: @@ -516,6 +539,12 @@ def _init_comm(self): ep for idx, ep in enumerate(self.endpoints) if idx % self._inner_parallelism_size == _offset ] + + # the current entire model parallelism group is the combination of innert & sharding parallelism + self.mp_group_id = 2 + self.mp_rank = self.global_rank + self.mp_group_size = self.role_maker._worker_num() + self.mp_group_endpoints = self.endpoints[:] logging.info("Using Sharing as Outer parallelism mode !") # print( @@ -544,6 +573,13 @@ def _init_comm(self): self.sharding_rank = self.global_rank self.sharding_group_size = self.role_maker._worker_num() self.sharding_group_endpoints = self.endpoints + + # sharding parallelism is the only model parallelism in the current setting + self.mp_group_id = self.sharding_ring_id + self.mp_rank = self.sharding_rank + self.mp_group_size = self.sharding_group_size + self.mp_group_endpoints = self.sharding_group_endpoints[:] + logging.info("Using Sharing alone mode !") self.dp_ring_id = -1 @@ -552,15 +588,46 @@ def _init_comm(self): self.dp_group_endpoints = None logging.info("global word size: {}".format(self.global_word_size)) - logging.info("global rank: {}".format(self.global_rank)) + logging.info("global rank: {}".format(self.global_rank)) logging.info("sharding group_size: {}".format(self.sharding_group_size)) logging.info("sharding rank: {}".format(self.sharding_rank)) + logging.info("current model parallelism group_size: {}".format(self.mp_group_size)) + logging.info("current model parallelism rank: {}".format(self.mp_rank)) logging.info("dp group size: {}".format(self.dp_group_size)) logging.info("dp rank: {}".format(self.dp_rank)) logging.info("current endpoint: {}".format(self.current_endpoint)) + logging.info("global word endpoints: {}".format(self.endpoints)) logging.info("sharding group endpoints: {}".format( self.sharding_group_endpoints)) + logging.info("current model parallelism group endpoints: {}".format( + self.mp_group_endpoints)) logging.info("dp group endpoints: {}".format(self.dp_group_endpoints)) - logging.info("global word endpoints: {}".format(self.endpoints)) return + + def _initialization_broadcast(self, startup_prog): + """ + this funtion is to ensure the initialization between dp group to be + identical when hybrid-dp is used. 
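+        It appends a c_broadcast (root rank 0) on the dp ring for every
+        parameter, followed by one c_sync_comm_stream over all parameters.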
+ """ + block = startup_prog.global_block() + params = [] + for param in block.iter_parameters(): + params.append(param) + block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': self.dp_ring_id, + 'root': 0, + OP_ROLE_KEY: OpRole.Forward + }) + block.append_op( + type='c_sync_comm_stream', + inputs={'X': params}, + outputs={'Out': params}, + attrs={'ring_id': self.dp_ring_id, + OP_ROLE_KEY: OpRole.Forward}) + + From 5444c3b76e7feb80b36992a376bfbd6bd3c25bd5 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Thu, 4 Feb 2021 16:27:33 +0800 Subject: [PATCH 04/24] sharding-megatron suppoort gradclipbyglobalnorm --- .../sharding/gradient_clip_helper.py | 47 ++++++++++++------- .../meta_optimizers/sharding_optimizer.py | 9 ++-- 2 files changed, 34 insertions(+), 22 deletions(-) mode change 100644 => 100755 python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py old mode 100644 new mode 100755 index c6aee792fcf745..d0f8e13cdbf9cf --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -16,8 +16,8 @@ class GradientClipHelper(object): - def __init__(self, sharding_ring_id): - self.sharding_ring_id = sharding_ring_id + def __init__(self, mp_ring_id): + self.mp_ring_id = mp_ring_id def _is_gradient_clip_op(self, op): return op.desc.has_attr("op_namescope") \ @@ -31,6 +31,7 @@ def prune_gradient_clip(self, block, shard): """ deperated_vars = set() deperate_op_idx = set() + reversed_x_paramname = [] for idx, op in enumerate(block.ops): if not self._is_gradient_clip_op(op): continue @@ -44,6 +45,8 @@ def prune_gradient_clip(self, block, shard): if shard.is_param(param_name) and \ not shard.has_param(param_name): deperate_op = True + elif shard.is_param(param_name) : + reversed_x_paramname.append(param_name) if deperate_op: deperate_op_idx.add(idx) @@ -65,32 +68,42 @@ def prune_gradient_clip(self, block, shard): for input_name in op.desc.input_arg_names(): if input_name not in deperated_vars: reversed_inputs.append(input_name) + op.desc.set_input("X", reversed_inputs) assert (len(op.desc.output_arg_names()) == 1) sum_res = op.desc.output_arg_names()[0] - block._insert_op_without_sync( - idx + 1, - type='c_sync_comm_stream', - inputs={'X': sum_res}, - outputs={'Out': sum_res}, - attrs={'ring_id': 0, - OP_ROLE_KEY: OpRole.Optimize}) + + # this allreduce should not overlap with calc and should be scheduled in calc stream + # block._insert_op_without_sync( + # idx + 1, + # type='c_sync_comm_stream', + # inputs={'X': sum_res}, + # outputs={'Out': sum_res}, + # attrs={'ring_id': 0, + # OP_ROLE_KEY: OpRole.Optimize}) block._insert_op_without_sync( idx + 1, type='c_allreduce_sum', inputs={'X': sum_res}, outputs={'Out': sum_res}, attrs={ - 'ring_id': self.sharding_ring_id, - OP_ROLE_KEY: OpRole.Optimize + 'ring_id': self.mp_ring_id, + 'op_namescope': "/gradient_clip_model_parallelism", + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, }) - block._insert_op_without_sync( - idx + 1, - type='c_sync_calc_stream', - inputs={'X': sum_res}, - outputs={'Out': sum_res}, - attrs={OP_ROLE_KEY: OpRole.Optimize}) + # block._insert_op_without_sync( + # idx + 1, + # type='c_sync_calc_stream', + # inputs={'X': sum_res}, + # outputs={'Out': sum_res}, + # 
attrs={OP_ROLE_KEY: OpRole.Optimize}) + # the grad sum here should take the all and only param in the current shard + to_check_param = set(reversed_x_paramname) + should_check_param = set(shard.global_params).intersection(set([param for param, worker_idx in shard.global_param2device.items() if worker_idx == shard.worker_idx])) + assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(should_check_param - to_check_param, to_check_param - should_check_param) + for var_name in deperated_vars: block._remove_var(var_name, sync=False) block._sync_with_cpp() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index b45ab35a635564..891c6c12c4c81f 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -259,13 +259,12 @@ def _prune_main_program(self, block): weightdecay_helper.prune_weight_decay(block, self._shard) # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism # group. and each Data Parallelism group should have its own sync of FoundInfinite - FoundInfinite_ring_id = self.sharding_ring_id + Model_Paramllelism_ring_id = self.sharding_ring_id if self._as_outer_parallelism: - FoundInfinite_ring_id = self.mp_group_id + Model_Paramllelism_ring_id = self.mp_group_id FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, - FoundInfinite_ring_id) - - gradientclip_helper = GradientClipHelper(self.sharding_ring_id) + Model_Paramllelism_ring_id) + gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id) gradientclip_helper.prune_gradient_clip(block, self._shard) # build prog deps From 527bc962b25a149cfc668ed6ba1bb3d7997f1798 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 19 Feb 2021 16:52:14 +0800 Subject: [PATCH 05/24] Sharding allreduce --> reduce --- .../fleet/meta_optimizers/sharding/utils.py | 44 +++++++++++++++++-- .../meta_optimizers/sharding_optimizer.py | 10 +++-- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index eb5767ec4d3436..a36f555af8779e 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -88,7 +88,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): grad: - 0: op that generate Var - 1: sync_calc - - 2: allreduce_sum_sharding + - 2: reduce_sum_sharding (allreduce --> reduce) - 3: sync_comm - 4: allreuce_sum_dp (dp_grads) - 5: sync_comm (dp_grads) @@ -103,7 +103,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): idx_gradient_clip_allreduce = -1 for idx, op in enumerate(block.ops): - if op.type == "c_allreduce_sum": + if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum" : if op.all_attrs()["use_calc_stream"] == False: ring_id = op.desc.attr("ring_id") var_name = op.desc.input_arg_names()[0] @@ -137,11 +137,12 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): var_name] == 0: dp_grads_status[var_name] = 1 - elif op.type == "c_allreduce_sum": + elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum" : if op.all_attrs()["use_calc_stream"] == False: var_name = op.desc.input_arg_names()[0] ring_id = 
op.desc.attr("ring_id") if ring_id == sharding_ring_id: + assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce" if var_name in vars_status: _status = vars_status[var_name] else: @@ -191,6 +192,8 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): raise ValueError("There should be a sync_comm op " "after allreduce the Var: {}".format( input_name)) + raise ValueError("The reduce output grad [{}] should NOT be be used in Non-root rank.".format( + input_name)) if input_name in dp_grads_status: if dp_ring_id == -1: if dp_grads_status[input_name] != 3: @@ -202,6 +205,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): raise ValueError( "The grad in shard should be allreduce and sync" "twice before usage {}".format(input_name)) + for output_name in op.desc.output_arg_names(): if output_name in vars_status and \ @@ -334,6 +338,25 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): return +def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard): + """ + _add_allreduce_ops + """ + for var in reduce_vars: + root_id = get_grad_device(var, shard) + assert root_id >= 0, "root id should be a positive int".format(var) + block._insert_op_without_sync( + insert_idx, + type='c_reduce_sum', + inputs={'X': var}, + outputs={'Out': var}, + attrs={'ring_id': ring_id, + 'root_id': root_id, + OP_ROLE_KEY: OpRole.Backward}) + + return + + def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): """ @@ -511,3 +534,18 @@ def sharding_predicate(var): filename=None) return + +def get_grad_device(grad_name, shard): + assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(grad_name) + base_name = None + # mind the traversal order + possible_suffixes = ['.cast_fp16@GRAD', '@GRAD'] + for suffix in possible_suffixes: + if suffix in grad_name : + base_name = re.sub(suffix, '', grad_name) + break + + assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(base_name) + + return shard.global_param2device[base_name] + diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 891c6c12c4c81f..b63f2618f3a00a 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -351,9 +351,10 @@ def _add_broadcast_allreduce(self, block): insert_sync_comm_ops(block, self._segments[-1]._end_idx, self.sharding_ring_id, self._segments[-1]._allreduce_vars) - insert_allreduce_ops(block, self._segments[-1]._end_idx, + # allreduce --> reduce + insert_reduce_ops(block, self._segments[-1]._end_idx, self.sharding_ring_id, - self._segments[-1]._allreduce_vars) + self._segments[-1]._allreduce_vars, self._shard) for idx, segment in reversed(list(enumerate(self._segments))): allreduce_vars = self._segments[ @@ -432,8 +433,9 @@ def _add_broadcast_allreduce(self, block): insert_sync_comm_ops(block, segment._start_idx, self.sharding_ring_id, allreduce_vars) # sharding - insert_allreduce_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + # allreduce --> reduce + insert_reduce_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars, self._shard) block._sync_with_cpp() From 7e35d31d9e40272ec90d459497d43a959f1a1903 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 5 Mar 2021 11:15:28 +0800 Subject: [PATCH 06/24] sharding optimize init speed 
--- .../meta_optimizers/sharding_optimizer.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index b63f2618f3a00a..761fa93d5cd7dd 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -25,6 +25,9 @@ from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps from paddle.distributed.fleet.meta_optimizers.sharding.utils import * import logging +logging.basicConfig( + format='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') from functools import reduce __all__ = ["ShardingOptimizer"] @@ -78,6 +81,7 @@ def minimize_impl(self, no_grad_set=None): # TODO: (JZ-LIANG) support multiple comm in future # self._nrings = self.user_defined_strategy.nccl_comm_num + logging.info("########### into sharding minimize: wait at end") self._nrings_sharding = 1 self._nrings_dp = 1 self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[ @@ -95,6 +99,7 @@ def minimize_impl(self, "self.inner_opt of ShardingOptimizer should not be None.") optimize_ops, params_grads = self.inner_opt.minimize( loss, startup_program, parameter_list, no_grad_set) + logging.info("########### after inner opt") if startup_program is None: startup_program = default_startup_program() @@ -105,14 +110,17 @@ def minimize_impl(self, # step1: set_up self._set_up(params_grads) + logging.info("########### after inner _set_up") # step2: split_program self._split_program(main_block) + logging.info("########### after inner _split_program") # step3: add broadcast and reduce ops self._add_broadcast_allreduce(main_block) main_block._sync_with_cpp() startup_block._sync_with_cpp() + logging.info("########### after inner _add_broadcast_allreduce") # step4: insert reduce_sum for grad grad_scale_coeff = self.role_maker._worker_num() @@ -126,12 +134,14 @@ def minimize_impl(self, self._prune_startup_program(startup_block) if self.hybrid_dp: self._initialization_broadcast(startup_program) - + logging.info("########### after inner prune") # check op dependecy check_broadcast(main_block) check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, self.dp_ring_id) + logging.info("########### after inner check") self._wait() + logging.info("########### after inner wait") return optimize_ops, params_grads def _set_up(self, params_grads): @@ -148,20 +158,20 @@ def _set_up(self, params_grads): self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, self.sharding_group_endpoints, self.sharding_rank, - self.sharding_ring_id, True) + self.sharding_ring_id, False) # inner & outer model parallelism if self._as_outer_parallelism: self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, self.mp_group_endpoints, self.mp_rank, - self.mp_group_id, True) + self.mp_group_id, False) # dp if self.hybrid_dp: self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, - self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True) + self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, False) startup_block = self._startup_program.global_block() startup_block._sync_with_cpp() @@ -176,9 +186,10 @@ def _set_up(self, params_grads): self._main_program.global_block()) def _wait(self, ): - endpoints = self.role_maker._get_trainer_endpoints() + # 
only the first parallelsm group that init nccl need to be wait. + endpoints = self.sharding_group_endpoints[:] current_endpoint = endpoints[self.role_maker._worker_index()] - if self.role_maker._worker_index() == 0: + if self.sharding_rank == 0: self._collective_helper._wait(current_endpoint, endpoints) def _split_program(self, block): From 2cc1b7c4eae48d682dff71327ae1f2e600111e3f Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 10 Mar 2021 17:46:09 +0800 Subject: [PATCH 07/24] recompute remove useless log --- python/paddle/fluid/backward.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 0b4fa1469d77e3..bd20b6d31a4f70 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -829,8 +829,6 @@ def _append_backward_ops_with_checkpoints_( cross_vars = set(vars_should_be_hold) - set(checkpoints_name) _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ len(cross_vars), cross_vars)) - _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ - len(cross_vars), cross_vars)) # b. output of seed op should be kept in memory vars_should_be_hold.extend(program_stat.get_reserved_vars()) From f273420bdb5b1b88b5e0f1c594ff214f02f729bb Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Thu, 11 Mar 2021 17:41:47 +0800 Subject: [PATCH 08/24] sharding: segment strategy --- .../framework/distributed_strategy.proto | 4 +- .../meta_optimizers/sharding_optimizer.py | 92 +++++++++++++++++-- 2 files changed, 85 insertions(+), 11 deletions(-) mode change 100644 => 100755 paddle/fluid/framework/distributed_strategy.proto diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto old mode 100644 new mode 100755 index e735da6501b25a..a3608ca4476bfb --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -29,11 +29,13 @@ message RecomputeConfig { } message ShardingConfig { - optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; + optional float broadcast_MB = 1 [ default = 32.0 ]; optional bool hybrid_dp = 2 [ default = false ]; optional int32 sharding_group_size = 3 [ default = 8 ]; optional bool as_outer_parallelism = 4 [ default = false ]; optional int32 inner_parallelism_size = 5 [ default = 8 ]; + optional string sharding_segment_strategy = 6 [ default = 'broadcast_size' ]; + repeated string sharding_segment_anchors = 7; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 761fa93d5cd7dd..0b6f7d55b111ab 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -14,7 +14,7 @@ from paddle.fluid import unique_name, core import paddle.fluid as fluid - +import paddle.distributed.fleet as fleet from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper from paddle.distributed.fleet.meta_optimizers.common import is_backward_op from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase @@ -72,7 +72,7 @@ def _disable_strategy(self, dist_strategy): def _enable_strategy(self, dist_strategy, context): dist_strategy.sharding = True - 
dist_strategy.sharding_configs = {"fuse_broadcast_MB": 32} + dist_strategy.sharding_configs = {"broadcast_MB": 32} def minimize_impl(self, loss, @@ -84,8 +84,6 @@ def minimize_impl(self, logging.info("########### into sharding minimize: wait at end") self._nrings_sharding = 1 self._nrings_dp = 1 - self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[ - "fuse_broadcast_MB"] self.hybrid_dp = self.user_defined_strategy.sharding_configs[ "hybrid_dp"] self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[ @@ -93,6 +91,23 @@ def minimize_impl(self, self._inner_parallelism_size = int( self.user_defined_strategy.sharding_configs[ "inner_parallelism_size"]) + self._sharding_segment_strategy = str(self.user_defined_strategy.sharding_configs[ + "sharding_segment_strategy"]) + + if self._sharding_segment_strategy == "broadcast_size": + self._broadcast_MB = int(self.user_defined_strategy.sharding_configs[ + "broadcast_MB"]) + assert self._broadcast_MB > 0, "segment size should larger than zero !" + elif self._sharding_segment_strategy == "anchors": + self._sharding_segment_anchors = self.user_defined_strategy.sharding_configs[ + "sharding_segment_anchors"] + assert len(self._sharding_segment_anchors) > 0, "you should set the sharding segment anchors !" + self._backward_remain_anchors = self._sharding_segment_anchors[:] + self._forward_remain_anchors = [] + else: + raise NotImplementedError( + "the sharding segment strategy [{}] is not implemented".format(str(self._sharding_segment_strategy))) + if self.inner_opt is None: raise ValueError( @@ -188,25 +203,60 @@ def _set_up(self, params_grads): def _wait(self, ): # only the first parallelsm group that init nccl need to be wait. endpoints = self.sharding_group_endpoints[:] - current_endpoint = endpoints[self.role_maker._worker_index()] + current_endpoint = endpoints[self.sharding_rank] if self.sharding_rank == 0: self._collective_helper._wait(current_endpoint, endpoints) + def collect_segment(self, segment, op_idx, block): + segment._start_idx = op_idx + 1 + self._segments.insert(0, segment) + new_segment = ProgramSegment(block) + new_segment._end_idx = op_idx + 1 + + return new_segment + def _split_program(self, block): for op_idx, op in reversed(list(enumerate(block.ops))): if int(op.attr('op_role')) != int(OpRole.Optimize): last_backward_op_idx = op_idx + 1 break + + var2broadcast_time = dict() segment = ProgramSegment(block) segment._end_idx = last_backward_op_idx for op_idx in reversed(range(last_backward_op_idx)): op = block.ops[op_idx] assert (int(op.attr('op_role')) != int(OpRole.Optimize)) - if segment._param_mem >= self._fuse_broadcast_MB: - segment._start_idx = op_idx + 1 - self._segments.insert(0, segment) - segment = ProgramSegment(block) - segment._end_idx = op_idx + 1 + if self._sharding_segment_strategy == "broadcast_size": + if segment._param_mem >= self._broadcast_MB: + segment = self.collect_segment(segment, op_idx, block) + + elif self._sharding_segment_strategy == "anchors": + if int(op.attr('op_role')) == int(OpRole.Backward): + for input_name in op.desc.input_arg_names(): + + # NOTE (JZ-LIANG) naive rule to support amp, if amp change, should modify here accordingly + if 'AMPOptimizer' in fleet._get_applied_meta_list(): + if ".cast_fp16@GRAD" not in input_name: + continue + else: + input_name = input_name[:input_name.find(".cast_fp16@GRAD")] + + if input_name in self._backward_remain_anchors: + logging.info("backward segment:") + logging.info("op [{}] input [{}] output [{}]".format(op.desc.type() 
,op.desc.input_arg_names(), op.desc.output_arg_names())) + segment = self.collect_segment(segment, op_idx, block) + assert input_name not in self._forward_remain_anchors, "segment anchor [{}] met twice !".format(input_name) + self._backward_remain_anchors.remove(input_name) + self._forward_remain_anchors.append(input_name) + elif int(op.attr('op_role')) == int(OpRole.Forward): + for output_name in op.desc.output_arg_names(): + if output_name in self._forward_remain_anchors: + logging.info("forward segment:") + logging.info("op [{}] input [{}] output [{}]".format(op.desc.type() ,op.desc.input_arg_names(), op.desc.output_arg_names())) + segment = self.collect_segment(segment, op_idx, block) + self._forward_remain_anchors.remove(output_name) + # find broadcast vars for input_name in op.desc.input_arg_names(): @@ -224,6 +274,15 @@ def _split_program(self, block): broadcast_var_name = unique_name.generate(input_name + "@BroadCast") segment._fill_constant_vars.append(broadcast_var_name) + + # (JZ-LIANG) should use Param base name ? + broadcast_var_base_name = input_name + if "subprog" in broadcast_var_base_name: + # remove suffix + broadcast_var_base_name = broadcast_var_base_name[:broadcast_var_base_name.find(".subprog")] + + var2broadcast_time[broadcast_var_base_name] = var2broadcast_time.get(broadcast_var_base_name, 0) + 1 + segment._param2broadcast[input_name] = broadcast_var_name segment._broadcast_vars.append((broadcast_var_name, self._shard.device(input_name))) @@ -253,6 +312,19 @@ def _split_program(self, block): if segment._param_mem > 0: segment._start_idx = 0 self._segments.insert(0, segment) + + if self._sharding_segment_strategy == "anchors": + assert len(self._forward_remain_anchors) == 0 + assert len(self._backward_remain_anchors) == 0 + + for varname in sorted(var2broadcast_time, key=var2broadcast_time.get, reverse=True): + logging.info("Sharding broadcast: [{}] times [{}]".format(var2broadcast_time[varname], varname)) + for idx_ in range(len(self._segments)): + logging.info("segment [{}] :".format(idx_)) + logging.info("start op: [{}] [{}]".format(block.ops[self._segments[idx_]._start_idx].desc.type(), + block.ops[self._segments[idx_]._start_idx].desc.input_arg_names())) + logging.info("end op: [{}] [{}]".format(block.ops[self._segments[idx_]._end_idx].desc.type(), + block.ops[self._segments[idx_]._end_idx].desc.input_arg_names())) return def _prune_main_program(self, block): From 6a18b38527eb52080e2099880fd3886eeb7ea19b Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Mon, 15 Mar 2021 16:47:57 +0800 Subject: [PATCH 09/24] temp change for ernie_10b_two_branch --- .../fleet/meta_optimizers/sharding_optimizer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 0b6f7d55b111ab..6ba7a9e558c753 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -237,11 +237,15 @@ def _split_program(self, block): # NOTE (JZ-LIANG) naive rule to support amp, if amp change, should modify here accordingly if 'AMPOptimizer' in fleet._get_applied_meta_list(): - if ".cast_fp16@GRAD" not in input_name: + # if ".cast_fp16@GRAD" not in input_name: + # continue + # else: + # input_name = input_name[:input_name.find(".cast_fp16@GRAD")] + if (op.type != "cast" and op.type != "layer_norm") or "@Fetch_0" not in input_name: continue 
else: - input_name = input_name[:input_name.find(".cast_fp16@GRAD")] - + input_name = input_name[:input_name.find("@Fetch_0")] + if input_name in self._backward_remain_anchors: logging.info("backward segment:") logging.info("op [{}] input [{}] output [{}]".format(op.desc.type() ,op.desc.input_arg_names(), op.desc.output_arg_names())) @@ -314,8 +318,8 @@ def _split_program(self, block): self._segments.insert(0, segment) if self._sharding_segment_strategy == "anchors": - assert len(self._forward_remain_anchors) == 0 - assert len(self._backward_remain_anchors) == 0 + assert len(self._forward_remain_anchors) == 0, "remain anchors {}".format(self._forward_remain_anchors) + assert len(self._backward_remain_anchors) == 0, "remain anchors {}".format(self._backward_remain_anchors) for varname in sorted(var2broadcast_time, key=var2broadcast_time.get, reverse=True): logging.info("Sharding broadcast: [{}] times [{}]".format(var2broadcast_time[varname], varname)) From 98baf20db0da0ce0e7ea31e5b46cd76ba52103bc Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 16 Mar 2021 19:21:41 +0800 Subject: [PATCH 10/24] sharding: gradient merge --- .../framework/distributed_strategy.proto | 1 + .../fleet/meta_optimizers/sharding/utils.py | 1 - .../meta_optimizers/sharding_optimizer.py | 329 ++++++++++++++++-- 3 files changed, 305 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index a3608ca4476bfb..67923cf6cf441e 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -36,6 +36,7 @@ message ShardingConfig { optional int32 inner_parallelism_size = 5 [ default = 8 ]; optional string sharding_segment_strategy = 6 [ default = 'broadcast_size' ]; repeated string sharding_segment_anchors = 7; + optional int32 gradient_merge_acc_step = 8 [ default = 1 ]; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index a36f555af8779e..e553ad0d229409 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -548,4 +548,3 @@ def get_grad_device(grad_name, shard): assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(base_name) return shard.global_param2device[base_name] - diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 6ba7a9e558c753..bff1015645838c 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
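+# NOTE: the gradient-merge logic added below builds persistable gradient
+# buffers and extra ops at the Python level, hence the new imports here.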
+import paddle from paddle.fluid import unique_name, core import paddle.fluid as fluid import paddle.distributed.fleet as fleet @@ -24,6 +25,10 @@ from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps from paddle.distributed.fleet.meta_optimizers.sharding.utils import * +from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard + +from paddle.fluid import layers + import logging logging.basicConfig( format='%(asctime)s %(levelname)-8s %(message)s', @@ -107,8 +112,10 @@ def minimize_impl(self, else: raise NotImplementedError( "the sharding segment strategy [{}] is not implemented".format(str(self._sharding_segment_strategy))) - - + self._gradient_merge_acc_step = int(self.user_defined_strategy.sharding_configs[ + "gradient_merge_acc_step"]) + self._grad2merged_grad = dict() + if self.inner_opt is None: raise ValueError( "self.inner_opt of ShardingOptimizer should not be None.") @@ -150,13 +157,26 @@ def minimize_impl(self, if self.hybrid_dp: self._initialization_broadcast(startup_program) logging.info("########### after inner prune") - # check op dependecy - check_broadcast(main_block) - check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, - self.dp_ring_id) + + with open("main_before_gm_%d" % self.role_maker._worker_index(), 'w') as f: + f.writelines(str(self._main_program)) + + # step6: optional gradient merge + if self._gradient_merge_acc_step > 1: + self._sharding_gradient_merge(main_block) + + with open("main_after_gm_%d" % self.role_maker._worker_index(), 'w') as f: + f.writelines(str(self._main_program)) + + # # check op dependecy + # check_broadcast(main_block) + # check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, + # self.dp_ring_id) logging.info("########### after inner check") self._wait() logging.info("########### after inner wait") + + return optimize_ops, params_grads def _set_up(self, params_grads): @@ -422,7 +442,8 @@ def _prune_main_program(self, block): def _add_broadcast_allreduce(self, block): """ - _add_broadcast_allreduce + add broadcast allreduce op + if enable gradient_merge, insert related ops """ if len(self._segments) < 1: return @@ -430,11 +451,21 @@ def _add_broadcast_allreduce(self, block): if self._segments[-1]._allreduce_vars: shard_allredue_vars = self._shard.filter_grads(self._segments[-1] ._allreduce_vars) - if self.hybrid_dp and len(shard_allredue_vars) >= 1: - insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self.dp_ring_id, shard_allredue_vars) - insert_allreduce_ops(block, self._segments[-1]._end_idx, - self.dp_ring_id, shard_allredue_vars) + if self._gradient_merge_acc_step <= 1: + if self.hybrid_dp and len(shard_allredue_vars) >= 1: + insert_sync_comm_ops(block, self._segments[-1]._end_idx, + self.dp_ring_id, shard_allredue_vars) + insert_allreduce_ops(block, self._segments[-1]._end_idx, + self.dp_ring_id, shard_allredue_vars) + # gradient merge + else: + self.create_persistable_gradients_and_insert_merge_ops( + block, + self._startup_program.global_block(), + self._segments[-1]._end_idx, + shard_allredue_vars, + self._shard) + insert_sync_comm_ops(block, self._segments[-1]._end_idx, self.sharding_ring_id, self._segments[-1]._allreduce_vars) @@ -480,19 +511,27 @@ def _add_broadcast_allreduce(self, block): # step2: add Sync ops shard_allredue_vars = self._shard.filter_grads(allreduce_vars) - if self.hybrid_dp and 
len(shard_allredue_vars) >= 1: - insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id, - shard_allredue_vars) + if self._gradient_merge_acc_step <= 1: + if self.hybrid_dp and len(shard_allredue_vars) >= 1: + insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id, + shard_allredue_vars) + + broad_cast_vars = [x[0] for x in broadcast_vars] + if len(broad_cast_vars) > 0: + insert_sync_comm_ops(block, segment._end_idx, + self.sharding_ring_id, broad_cast_vars) + else: + comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars] + if len(comm_dep_vars) > 0: + insert_sync_comm_ops(block, segment._end_idx, + self.sharding_ring_id, comm_dep_vars) + # gradient merge + else: broad_cast_vars = [x[0] for x in broadcast_vars] if len(broad_cast_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, - self.sharding_ring_id, broad_cast_vars) - else: - comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars] - if len(comm_dep_vars) > 0: - insert_sync_comm_ops(block, segment._end_idx, - self.sharding_ring_id, comm_dep_vars) + self.sharding_ring_id, broad_cast_vars) calc_dep_vars = fill_constant_vars + [ k for k, v in cast_ops.items() @@ -510,15 +549,30 @@ def _add_broadcast_allreduce(self, block): insert_cast_ops(block, segment._end_idx, cast_ops) # step5: add broadcast ops + # gradient merge + if self._gradient_merge_acc_step > 1: + self.create_persistable_gradients_and_insert_merge_ops( + block, + self._startup_program.global_block(), + segment._start_idx, + shard_allredue_vars, + self._shard) + insert_broadcast_ops(block, segment._start_idx, self.sharding_ring_id, broadcast_vars) + # step6: add all_reduce ops # dp - if self.hybrid_dp and len(shard_allredue_vars) >= 1: - insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id, - shard_allredue_vars) + if self._gradient_merge_acc_step <= 1: + if self.hybrid_dp and len(shard_allredue_vars) >= 1: + insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id, + shard_allredue_vars) + insert_sync_comm_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars) + # gradient merge + else: insert_sync_comm_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + self.sharding_ring_id, allreduce_vars) # sharding # allreduce --> reduce insert_reduce_ops(block, segment._start_idx, @@ -718,4 +772,229 @@ def _initialization_broadcast(self, startup_prog): attrs={'ring_id': self.dp_ring_id, OP_ROLE_KEY: OpRole.Forward}) + # sharding gradient merge + def create_persistable_gradients_and_insert_merge_ops(self, main_block, startup_block, insert_idx, grad_names, shard): + + for grad_name in grad_names: + assert get_grad_device(grad_name, shard) == shard.worker_idx, "try to merge gradient not belong to current shard: [{}]".format(grad_name) + persistable_grad_name = grad_name + '@GradiantMerge' + assert grad_name not in self._grad2merged_grad, "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format(grad_name) + self._grad2merged_grad[grad_name] = persistable_grad_name + grad_var = main_block.var(grad_name) + # create var + gradient_merge_var = main_block.create_var( + name=persistable_grad_name, + shape=grad_var.shape, + dtype=grad_var.dtype, + persistable=True) + startup_gradient_merge_var = startup_block.create_var( + name=persistable_grad_name, + shape=grad_var.shape, + dtype=grad_var.dtype, + persistable=True) + + # merge gradient + main_block._insert_op_without_sync( + insert_idx, + type="elementwise_add", + inputs={'X': grad_name, + 'Y': gradient_merge_var}, + 
outputs={'Out': gradient_merge_var}, + attrs={'axis': -1, + 'use_mkldnn': False, + OP_ROLE_KEY: OpRole.Backward}) + + # startup initialization + startup_block.append_op( + type="fill_constant", + outputs={"Out": startup_gradient_merge_var}, + attrs={ + "shape": grad_var.shape, + "dtype": grad_var.dtype, + "value": float(0), + }) + + main_block._sync_with_cpp() + startup_block._sync_with_cpp() + + + def _create_gm_cond(self, main_block): + # Add const var + acc_step_var = layers.create_global_var( + name="gradient_merge_acc_step", + shape=[1], + value=int(self._gradient_merge_acc_step), + dtype='int32', + persistable=True, + force_cpu=True) + + zero_var = layers.create_global_var( + name="gradient_merge_zero", + shape=[1], + value=int(0), + dtype='int32', + persistable=True, + force_cpu=True) + + # Add step var & cond var + current_step_var = layers.create_global_var( + name="gradient_merge_current_step", + shape=[1], + value=int(0), + dtype='int32', + persistable=True, + force_cpu=True) + + cond_var = layers.create_global_var( + name="gradient_merge_cond", + shape=[1], + value=bool(0), + dtype='bool', + persistable=True, + force_cpu=True) + + with device_guard("cpu"): + # step_var = (step_var + 1) % k_step + main_block.append_op( + type='increment', + inputs={'X': [current_step_var]}, + outputs={'Out': [current_step_var]}, + attrs={'step': float(1), + OP_ROLE_KEY: OpRole.Optimize}) + + main_block.append_op( + type='elementwise_mod', + inputs={'X': current_step_var, + 'Y': acc_step_var}, + outputs={'Out': current_step_var}, + attrs={'axis': -1, + OP_ROLE_KEY: OpRole.Optimize, + 'use_mkldnn': False}) + + # cond_var = (step_var == 0) + main_block.append_op( + type='equal', + inputs={'X': current_step_var, + 'Y': zero_var}, + outputs={'Out': cond_var}, + attrs={OP_ROLE_KEY: OpRole.Optimize} + ) + paddle.static.Print(current_step_var, message="in FWBW last conditional") + return cond_var + + def _true_apply_gradient(self): + """ + allreduce grad@gradientmerge in dp group + grad@gradientmerge / acc_step + re-create all optimize ops of origin main block and rename them + cast(backward) + amp + clip + opt + # fill constant grad@gradientmerge + + """ + # current conditional block + main_block = self._main_program.global_block() + cur_block_idx = self._main_program.current_block_idx + cur_block = self._main_program.current_block() + + # cur_block's forward_block & backward_block is itself + cur_block._set_forward_block_idx(cur_block_idx) + + # allreduce grad@gradientmerge + if self.hybrid_dp: + assert self.dp_ring_id >= 0, "dp_ring_id should larger than 0 when in sharding&DP mode" + for grad, merged_grad in self._grad2merged_grad.items(): + merged_grad_var = main_block.var(merged_grad) + cur_block.append_op( + type='c_allreduce_sum', + inputs={'X': merged_grad_var}, + outputs={'Out': merged_grad_var}, + attrs={'ring_id': self.dp_ring_id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize}) + + # grad@gradientmerge / acc_step + for grad, merged_grad in self._grad2merged_grad.items(): + # grad /= k_steps + merged_grad_var = main_block.var(merged_grad) + cur_block.append_op( + type='scale', + inputs={'X': merged_grad_var}, + outputs={'Out': merged_grad_var}, + attrs={ + 'scale': 1.0 / float(self._gradient_merge_acc_step), + 'bias': 0.0, + 'bias_after_scale': False, + OP_ROLE_KEY: OpRole.Optimize + }) + + # re-create optimize ops + for op_desc in self.original_optimize_ops_desc: + new_op_desc = cur_block.desc.append_op() + new_op_desc.copy_from(op_desc) + + for input_name in 
new_op_desc.input_arg_names(): + if input_name in self._grad2merged_grad: + new_op_desc._rename_input(input_name, self._grad2merged_grad[input_name]) + + for output_name in new_op_desc.output_arg_names(): + if output_name in self._grad2merged_grad: + new_op_desc._rename_input(output_name, self._grad2merged_grad[output_name]) + + cur_block._sync_with_cpp() + # fill zero to grad@gradientmerge + for grad, merged_grad in self._grad2merged_grad.items(): + merged_grad_var = main_block.var(merged_grad) + cur_block.append_op( + type='fill_constant', + outputs={'Out': merged_grad_var}, + attrs={ + "shape": merged_grad_var.shape, + "dtype": merged_grad_var.dtype, + "value": float(0), + OP_ROLE_KEY: OpRole.Optimize + }) + + lr_var = main_block.var("@LR_DECAY_COUNTER@") + paddle.static.Print(lr_var, message="in OPTIMIZE last conditional") + + + def _sharding_gradient_merge(self, main_block): + + """ + copy all optimize ops in origin main block + remove all optimize ops in origin main block + create cond block + + """ + # copy original optimize ops to temp ops desc list + # remove them from block 0 + tmp_copy_block = self._main_program._create_block() + + self.original_optimize_ops_desc = [] + for op_idx, op in reversed(list(enumerate(main_block.ops))): + if int(op.attr('op_role')) != int(OpRole.Optimize): + continue + else: + tmp_op_desc = tmp_copy_block.desc.append_op() + tmp_op_desc.copy_from(op.desc) + self.original_optimize_ops_desc.append(tmp_op_desc) + main_block._remove_op(op_idx, sync=False) + tmp_copy_block._sync_with_cpp() + self.original_optimize_ops_desc = list(reversed(self.original_optimize_ops_desc)) + print("original_optimize_ops_desc :") + for desc in self.original_optimize_ops_desc: + print(desc.type(), desc.input_arg_names()) + + # back to block 0 + self._main_program._rollback() + + # create cond vars and ops at the end of block 0 + cond = self._create_gm_cond(main_block) + + # create cond block + layers.cond(cond, true_fn=self._true_apply_gradient, false_fn=None) + \ No newline at end of file From 9ece14f356b15490649187287c46f69e1694e4b1 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 19 Mar 2021 11:41:22 +0800 Subject: [PATCH 11/24] sharding gradient merge: fix OOM --- .../meta_optimizers/sharding_optimizer.py | 68 ++++++++++++++++++- 1 file changed, 65 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index bff1015645838c..536cbdc5115415 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -851,7 +851,7 @@ def _create_gm_cond(self, main_block): shape=[1], value=bool(0), dtype='bool', - persistable=True, + persistable=False, force_cpu=True) with device_guard("cpu"): @@ -899,6 +899,7 @@ def _true_apply_gradient(self): main_block = self._main_program.global_block() cur_block_idx = self._main_program.current_block_idx cur_block = self._main_program.current_block() + self.cond_block = self._main_program.current_block() # cur_block's forward_block & backward_block is itself cur_block._set_forward_block_idx(cur_block_idx) @@ -915,6 +916,7 @@ def _true_apply_gradient(self): attrs={'ring_id': self.dp_ring_id, 'use_calc_stream': True, OP_ROLE_KEY: OpRole.Optimize}) + print("after allreduce grad@gradientmerge ") # grad@gradientmerge / acc_step for grad, merged_grad in self._grad2merged_grad.items(): @@ -930,21 +932,46 @@ def 
_true_apply_gradient(self): 'bias_after_scale': False, OP_ROLE_KEY: OpRole.Optimize }) + print("after allreduce grad@gradientmerge / acc_step") # re-create optimize ops + already_moved_var_names = [] for op_desc in self.original_optimize_ops_desc: new_op_desc = cur_block.desc.append_op() new_op_desc.copy_from(op_desc) + if op_desc.type in ["check_finite_and_unscale", "update_loss_scaling"]: + print("check @@@@@: ", op_desc.type) + print("input @@@@@: ", op_desc.output_arg_names) + for input_name in new_op_desc.input_arg_names(): if input_name in self._grad2merged_grad: new_op_desc._rename_input(input_name, self._grad2merged_grad[input_name]) for output_name in new_op_desc.output_arg_names(): if output_name in self._grad2merged_grad: - new_op_desc._rename_input(output_name, self._grad2merged_grad[output_name]) + new_op_desc._rename_output(output_name, self._grad2merged_grad[output_name]) + # move non temp optimize vars from block0 to cond block + if output_name not in already_moved_var_names and output_name not in self._grad2merged_grad.keys(): + var_ = self._main_program.global_block().var(output_name) + if not var_.persistable: + print("remove: ", output_name) + # move + name_ = var_.name + shape_ = var_.shape + type_ = var_.dtype + self._main_program.global_block()._remove_var(var_.name, sync=False) + self.cond_block.create_var( + name=name_, + shape=shape_, + dtype=type_, + persistable=False) + already_moved_var_names.append(name_) + + self._main_program.global_block()._sync_with_cpp() cur_block._sync_with_cpp() + # fill zero to grad@gradientmerge for grad, merged_grad in self._grad2merged_grad.items(): merged_grad_var = main_block.var(merged_grad) @@ -991,10 +1018,45 @@ def _sharding_gradient_merge(self, main_block): # back to block 0 self._main_program._rollback() + print("after first rollback") # create cond vars and ops at the end of block 0 cond = self._create_gm_cond(main_block) + print("after _create_gm_cond") + # create cond block - layers.cond(cond, true_fn=self._true_apply_gradient, false_fn=None) + cond_block = self._main_program._create_block() + self._true_apply_gradient() + print("after _true_apply_gradient") + + # back to block 0 + self._main_program._rollback() + + # cond op + step_scope = self._main_program.global_block().create_var( + type=core.VarDesc.VarType.STEP_SCOPES) + conditional_block_op = self._main_program.global_block().append_op( + type='conditional_block', + inputs={ + 'Cond': cond, + 'Input': [], + }, + outputs={'Out': [], + 'Scope': [step_scope]}, + attrs={ + 'sub_block': cond_block, + 'is_scalar_condition': True, + }) + + + + + + + + + + + \ No newline at end of file From add91b77267ab0447b954715f42ea2c54680ea0a Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 19 Mar 2021 19:47:21 +0800 Subject: [PATCH 12/24] sharding: revise save logic for gradient merge --- .../distributed/fleet/meta_optimizers/sharding/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index e553ad0d229409..0ee43e6157725d 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -514,13 +514,18 @@ def is_opt_vars(var): if var.name.endswith(check): return True return False + + def is_gradient_merge_vars(var): + # NOTE(liangjianzhong): to revise save/load logic in framework instead of write this naive rule + + return 
var.name.endswith("@GradiantMerge") def is_trainable(var): return isinstance(var, paddle.fluid.framework.Parameter) and var.trainable def sharding_predicate(var): - return is_trainable(var) or is_opt_vars(var) + return is_trainable(var) or is_opt_vars(var) or is_gradient_merge_vars(var) if int(os.environ.get('PADDLE_TRAINER_ID', 0)) == 0: paddle.fluid.io.save_persistables( From e01e22a8c6d9af54545658d84e76b5e24fa1b6ba Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 26 Mar 2021 16:47:30 +0800 Subject: [PATCH 13/24] Sharding: revise code format --- .../distributed/fleet/meta_optimizers/sharding_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 536cbdc5115415..31178b74c17753 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -15,7 +15,7 @@ import paddle from paddle.fluid import unique_name, core import paddle.fluid as fluid -import paddle.distributed.fleet as fleet +from paddle.distributed import fleet from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper from paddle.distributed.fleet.meta_optimizers.common import is_backward_op from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase From ffb492b7367397c25606e7bd716a01e4e1305eea Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 26 Mar 2021 20:56:37 +0800 Subject: [PATCH 14/24] sharding: update anchor segment strategy --- .../fleet/meta_optimizers/sharding_optimizer.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 31178b74c17753..0b004b9816d1fd 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -256,16 +256,12 @@ def _split_program(self, block): for input_name in op.desc.input_arg_names(): # NOTE (JZ-LIANG) naive rule to support amp, if amp change, should modify here accordingly - if 'AMPOptimizer' in fleet._get_applied_meta_list(): - # if ".cast_fp16@GRAD" not in input_name: - # continue - # else: - # input_name = input_name[:input_name.find(".cast_fp16@GRAD")] - if (op.type != "cast" and op.type != "layer_norm") or "@Fetch_0" not in input_name: + if self.user_defined_strategy.amp: + if ".cast_fp16@GRAD" not in input_name: continue else: - input_name = input_name[:input_name.find("@Fetch_0")] - + input_name = input_name[:input_name.find(".cast_fp16@GRAD")] + if input_name in self._backward_remain_anchors: logging.info("backward segment:") logging.info("op [{}] input [{}] output [{}]".format(op.desc.type() ,op.desc.input_arg_names(), op.desc.output_arg_names())) From 0abe6e905a9a2e22417520c6f10712e693699ab2 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 26 Mar 2021 21:53:50 +0800 Subject: [PATCH 15/24] sharding: revise anchor segment logic --- .../fleet/meta_optimizers/amp_optimizer.py | 3 +- .../meta_optimizers/sharding_optimizer.py | 284 +++++++++--------- 2 files changed, 141 insertions(+), 146 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py index 
858a28e6773f11..8e4ddedadf0682 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -57,7 +57,8 @@ def _init_wrapped_opt(self): # add is_distributed to optimize amp, overlap communication and # computation by split the check_finite_and_unscale op. is_distributed = self.role_maker._worker_num() > 1 - if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel: + if self.user_defined_strategy.sharding: + # if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel: # FIXME(wangxi). sharding failed when split check_finite_and_unscale # FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior is_distributed = False diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 0b004b9816d1fd..5864c060ed0778 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -15,7 +15,6 @@ import paddle from paddle.fluid import unique_name, core import paddle.fluid as fluid -from paddle.distributed import fleet from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper from paddle.distributed.fleet.meta_optimizers.common import is_backward_op from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase @@ -96,26 +95,30 @@ def minimize_impl(self, self._inner_parallelism_size = int( self.user_defined_strategy.sharding_configs[ "inner_parallelism_size"]) - self._sharding_segment_strategy = str(self.user_defined_strategy.sharding_configs[ - "sharding_segment_strategy"]) + self._sharding_segment_strategy = str( + self.user_defined_strategy.sharding_configs[ + "sharding_segment_strategy"]) if self._sharding_segment_strategy == "broadcast_size": - self._broadcast_MB = int(self.user_defined_strategy.sharding_configs[ - "broadcast_MB"]) - assert self._broadcast_MB > 0, "segment size should larger than zero !" + self._broadcast_MB = int( + self.user_defined_strategy.sharding_configs["broadcast_MB"]) + assert self._broadcast_MB > 0, "segment size should larger than zero !" elif self._sharding_segment_strategy == "anchors": self._sharding_segment_anchors = self.user_defined_strategy.sharding_configs[ "sharding_segment_anchors"] - assert len(self._sharding_segment_anchors) > 0, "you should set the sharding segment anchors !" + assert len(self._sharding_segment_anchors + ) > 0, "you should set the sharding segment anchors !" 
self._backward_remain_anchors = self._sharding_segment_anchors[:] self._forward_remain_anchors = [] else: raise NotImplementedError( - "the sharding segment strategy [{}] is not implemented".format(str(self._sharding_segment_strategy))) - self._gradient_merge_acc_step = int(self.user_defined_strategy.sharding_configs[ + "the sharding segment strategy [{}] is not implemented".format( + str(self._sharding_segment_strategy))) + self._gradient_merge_acc_step = int( + self.user_defined_strategy.sharding_configs[ "gradient_merge_acc_step"]) self._grad2merged_grad = dict() - + if self.inner_opt is None: raise ValueError( "self.inner_opt of ShardingOptimizer should not be None.") @@ -158,14 +161,16 @@ def minimize_impl(self, self._initialization_broadcast(startup_program) logging.info("########### after inner prune") - with open("main_before_gm_%d" % self.role_maker._worker_index(), 'w') as f: + with open("main_before_gm_%d" % self.role_maker._worker_index(), + 'w') as f: f.writelines(str(self._main_program)) - + # step6: optional gradient merge if self._gradient_merge_acc_step > 1: self._sharding_gradient_merge(main_block) - with open("main_after_gm_%d" % self.role_maker._worker_index(), 'w') as f: + with open("main_after_gm_%d" % self.role_maker._worker_index(), + 'w') as f: f.writelines(str(self._main_program)) # # check op dependecy @@ -176,7 +181,6 @@ def minimize_impl(self, self._wait() logging.info("########### after inner wait") - return optimize_ops, params_grads def _set_up(self, params_grads): @@ -199,8 +203,7 @@ def _set_up(self, params_grads): if self._as_outer_parallelism: self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, - self.mp_group_endpoints, self.mp_rank, - self.mp_group_id, False) + self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False) # dp if self.hybrid_dp: @@ -250,7 +253,7 @@ def _split_program(self, block): if self._sharding_segment_strategy == "broadcast_size": if segment._param_mem >= self._broadcast_MB: segment = self.collect_segment(segment, op_idx, block) - + elif self._sharding_segment_strategy == "anchors": if int(op.attr('op_role')) == int(OpRole.Backward): for input_name in op.desc.input_arg_names(): @@ -260,25 +263,34 @@ def _split_program(self, block): if ".cast_fp16@GRAD" not in input_name: continue else: - input_name = input_name[:input_name.find(".cast_fp16@GRAD")] + input_name = input_name[:input_name.find( + ".cast_fp16@GRAD")] if input_name in self._backward_remain_anchors: logging.info("backward segment:") - logging.info("op [{}] input [{}] output [{}]".format(op.desc.type() ,op.desc.input_arg_names(), op.desc.output_arg_names())) - segment = self.collect_segment(segment, op_idx, block) - assert input_name not in self._forward_remain_anchors, "segment anchor [{}] met twice !".format(input_name) + logging.info("op [{}] input [{}] output [{}]". 
+ format(op.desc.type(), + op.desc.input_arg_names(), + op.desc.output_arg_names())) + segment = self.collect_segment(segment, op_idx, + block) + assert input_name not in self._forward_remain_anchors, "segment anchor [{}] met twice !".format( + input_name) self._backward_remain_anchors.remove(input_name) self._forward_remain_anchors.append(input_name) elif int(op.attr('op_role')) == int(OpRole.Forward): for output_name in op.desc.output_arg_names(): if output_name in self._forward_remain_anchors: logging.info("forward segment:") - logging.info("op [{}] input [{}] output [{}]".format(op.desc.type() ,op.desc.input_arg_names(), op.desc.output_arg_names())) - segment = self.collect_segment(segment, op_idx, block) + logging.info("op [{}] input [{}] output [{}]". + format(op.desc.type(), + op.desc.input_arg_names(), + op.desc.output_arg_names())) + segment = self.collect_segment(segment, op_idx, + block) self._forward_remain_anchors.remove(output_name) - - # find broadcast vars +# find broadcast vars for input_name in op.desc.input_arg_names(): if input_name not in self._broadcast_vars: continue @@ -299,9 +311,15 @@ def _split_program(self, block): broadcast_var_base_name = input_name if "subprog" in broadcast_var_base_name: # remove suffix - broadcast_var_base_name = broadcast_var_base_name[:broadcast_var_base_name.find(".subprog")] + broadcast_var_base_name = broadcast_var_base_name[: + broadcast_var_base_name. + find( + ".subprog" + )] - var2broadcast_time[broadcast_var_base_name] = var2broadcast_time.get(broadcast_var_base_name, 0) + 1 + var2broadcast_time[ + broadcast_var_base_name] = var2broadcast_time.get( + broadcast_var_base_name, 0) + 1 segment._param2broadcast[input_name] = broadcast_var_name segment._broadcast_vars.append((broadcast_var_name, @@ -334,17 +352,25 @@ def _split_program(self, block): self._segments.insert(0, segment) if self._sharding_segment_strategy == "anchors": - assert len(self._forward_remain_anchors) == 0, "remain anchors {}".format(self._forward_remain_anchors) - assert len(self._backward_remain_anchors) == 0, "remain anchors {}".format(self._backward_remain_anchors) - - for varname in sorted(var2broadcast_time, key=var2broadcast_time.get, reverse=True): - logging.info("Sharding broadcast: [{}] times [{}]".format(var2broadcast_time[varname], varname)) + assert len( + self._forward_remain_anchors) == 0, "remain anchors {}".format( + self._forward_remain_anchors) + assert len( + self._backward_remain_anchors) == 0, "remain anchors {}".format( + self._backward_remain_anchors) + + for varname in sorted( + var2broadcast_time, key=var2broadcast_time.get, reverse=True): + logging.info("Sharding broadcast: [{}] times [{}]".format( + var2broadcast_time[varname], varname)) for idx_ in range(len(self._segments)): logging.info("segment [{}] :".format(idx_)) - logging.info("start op: [{}] [{}]".format(block.ops[self._segments[idx_]._start_idx].desc.type(), - block.ops[self._segments[idx_]._start_idx].desc.input_arg_names())) - logging.info("end op: [{}] [{}]".format(block.ops[self._segments[idx_]._end_idx].desc.type(), - block.ops[self._segments[idx_]._end_idx].desc.input_arg_names())) + logging.info("start op: [{}] [{}]".format(block.ops[self._segments[ + idx_]._start_idx].desc.type(), block.ops[self._segments[ + idx_]._start_idx].desc.input_arg_names())) + logging.info("end op: [{}] [{}]".format(block.ops[self._segments[ + idx_]._end_idx].desc.type(), block.ops[self._segments[ + idx_]._end_idx].desc.input_arg_names())) return def _prune_main_program(self, block): @@ -450,16 
+476,15 @@ def _add_broadcast_allreduce(self, block): if self._gradient_merge_acc_step <= 1: if self.hybrid_dp and len(shard_allredue_vars) >= 1: insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self.dp_ring_id, shard_allredue_vars) + self.dp_ring_id, shard_allredue_vars) insert_allreduce_ops(block, self._segments[-1]._end_idx, - self.dp_ring_id, shard_allredue_vars) + self.dp_ring_id, shard_allredue_vars) # gradient merge else: self.create_persistable_gradients_and_insert_merge_ops( - block, - self._startup_program.global_block(), - self._segments[-1]._end_idx, - shard_allredue_vars, + block, + self._startup_program.global_block(), + self._segments[-1]._end_idx, shard_allredue_vars, self._shard) insert_sync_comm_ops(block, self._segments[-1]._end_idx, @@ -467,8 +492,8 @@ def _add_broadcast_allreduce(self, block): self._segments[-1]._allreduce_vars) # allreduce --> reduce insert_reduce_ops(block, self._segments[-1]._end_idx, - self.sharding_ring_id, - self._segments[-1]._allreduce_vars, self._shard) + self.sharding_ring_id, + self._segments[-1]._allreduce_vars, self._shard) for idx, segment in reversed(list(enumerate(self._segments))): allreduce_vars = self._segments[ @@ -510,24 +535,28 @@ def _add_broadcast_allreduce(self, block): if self._gradient_merge_acc_step <= 1: if self.hybrid_dp and len(shard_allredue_vars) >= 1: - insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id, - shard_allredue_vars) + insert_sync_comm_ops(block, segment._end_idx, + self.dp_ring_id, shard_allredue_vars) broad_cast_vars = [x[0] for x in broadcast_vars] if len(broad_cast_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, - self.sharding_ring_id, broad_cast_vars) + self.sharding_ring_id, + broad_cast_vars) else: - comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars] + comm_dep_vars = allreduce_vars + [ + x[0] for x in broadcast_vars + ] if len(comm_dep_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, - self.sharding_ring_id, comm_dep_vars) + self.sharding_ring_id, + comm_dep_vars) # gradient merge else: broad_cast_vars = [x[0] for x in broadcast_vars] if len(broad_cast_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, - self.sharding_ring_id, broad_cast_vars) + self.sharding_ring_id, broad_cast_vars) calc_dep_vars = fill_constant_vars + [ k for k, v in cast_ops.items() @@ -548,11 +577,9 @@ def _add_broadcast_allreduce(self, block): # gradient merge if self._gradient_merge_acc_step > 1: self.create_persistable_gradients_and_insert_merge_ops( - block, - self._startup_program.global_block(), - segment._start_idx, - shard_allredue_vars, - self._shard) + block, + self._startup_program.global_block(), segment._start_idx, + shard_allredue_vars, self._shard) insert_broadcast_ops(block, segment._start_idx, self.sharding_ring_id, broadcast_vars) @@ -561,18 +588,18 @@ def _add_broadcast_allreduce(self, block): # dp if self._gradient_merge_acc_step <= 1: if self.hybrid_dp and len(shard_allredue_vars) >= 1: - insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id, - shard_allredue_vars) + insert_allreduce_ops(block, segment._start_idx, + self.dp_ring_id, shard_allredue_vars) insert_sync_comm_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + self.sharding_ring_id, allreduce_vars) # gradient merge else: insert_sync_comm_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + self.sharding_ring_id, allreduce_vars) # sharding # allreduce --> reduce - insert_reduce_ops(block, segment._start_idx, - self.sharding_ring_id, 
allreduce_vars, self._shard) + insert_reduce_ops(block, segment._start_idx, self.sharding_ring_id, + allreduce_vars, self._shard) block._sync_with_cpp() @@ -660,7 +687,6 @@ def _init_comm(self): self.mp_group_size = self.sharding_group_size self.mp_group_endpoints = self.sharding_group_endpoints[:] - logging.info("Using Sharing&DP mode !") else: if self._as_outer_parallelism: @@ -685,27 +711,6 @@ def _init_comm(self): self.mp_group_endpoints = self.endpoints[:] logging.info("Using Sharing as Outer parallelism mode !") - # print( - # "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer" - # ) - # partition_idx = self.global_rank // self._inner_parallelism_size - # magetron_endpoints = self.endpoints[ - # partition_idx * self._inner_parallelism_size:partition_idx * - # self._inner_parallelism_size + self._inner_parallelism_size] - # magetron_rank = self.global_rank % self._inner_parallelism_size - - # self._collective_helper._init_communicator( - # program=self._startup_program, - # current_endpoint=self.current_endpoint, - # endpoints=magetron_endpoints, - # rank=magetron_rank, - # ring_id=0, - # wait_port=True) - # logging.info("megatron group size: {}".format( - # self._inner_parallelism_size)) - # logging.info("megatron rank: {}".format(magetron_rank)) - # logging.info("megatron endpoints: {}".format( - # magetron_endpoints)) else: self.sharding_ring_id = 0 self.sharding_rank = self.global_rank @@ -726,11 +731,12 @@ def _init_comm(self): self.dp_group_endpoints = None logging.info("global word size: {}".format(self.global_word_size)) - logging.info("global rank: {}".format(self.global_rank)) + logging.info("global rank: {}".format(self.global_rank)) logging.info("sharding group_size: {}".format(self.sharding_group_size)) logging.info("sharding rank: {}".format(self.sharding_rank)) - logging.info("current model parallelism group_size: {}".format(self.mp_group_size)) - logging.info("current model parallelism rank: {}".format(self.mp_rank)) + logging.info("current model parallelism group_size: {}".format( + self.mp_group_size)) + logging.info("current model parallelism rank: {}".format(self.mp_rank)) logging.info("dp group size: {}".format(self.dp_group_size)) logging.info("dp rank: {}".format(self.dp_rank)) logging.info("current endpoint: {}".format(self.current_endpoint)) @@ -769,12 +775,17 @@ def _initialization_broadcast(self, startup_prog): OP_ROLE_KEY: OpRole.Forward}) # sharding gradient merge - def create_persistable_gradients_and_insert_merge_ops(self, main_block, startup_block, insert_idx, grad_names, shard): + def create_persistable_gradients_and_insert_merge_ops( + self, main_block, startup_block, insert_idx, grad_names, shard): for grad_name in grad_names: - assert get_grad_device(grad_name, shard) == shard.worker_idx, "try to merge gradient not belong to current shard: [{}]".format(grad_name) + assert get_grad_device( + grad_name, shard + ) == shard.worker_idx, "try to merge gradient not belong to current shard: [{}]".format( + grad_name) persistable_grad_name = grad_name + '@GradiantMerge' - assert grad_name not in self._grad2merged_grad, "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format(grad_name) + assert grad_name not in self._grad2merged_grad, "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format( + grad_name) self._grad2merged_grad[grad_name] = persistable_grad_name grad_var = main_block.var(grad_name) # create var @@ -796,9 +807,11 @@ def 
create_persistable_gradients_and_insert_merge_ops(self, main_block, startup_ inputs={'X': grad_name, 'Y': gradient_merge_var}, outputs={'Out': gradient_merge_var}, - attrs={'axis': -1, - 'use_mkldnn': False, - OP_ROLE_KEY: OpRole.Backward}) + attrs={ + 'axis': -1, + 'use_mkldnn': False, + OP_ROLE_KEY: OpRole.Backward + }) # startup initialization startup_block.append_op( @@ -813,8 +826,6 @@ def create_persistable_gradients_and_insert_merge_ops(self, main_block, startup_ main_block._sync_with_cpp() startup_block._sync_with_cpp() - - def _create_gm_cond(self, main_block): # Add const var acc_step_var = layers.create_global_var( @@ -857,16 +868,18 @@ def _create_gm_cond(self, main_block): inputs={'X': [current_step_var]}, outputs={'Out': [current_step_var]}, attrs={'step': float(1), - OP_ROLE_KEY: OpRole.Optimize}) + OP_ROLE_KEY: OpRole.Optimize}) main_block.append_op( type='elementwise_mod', inputs={'X': current_step_var, 'Y': acc_step_var}, outputs={'Out': current_step_var}, - attrs={'axis': -1, - OP_ROLE_KEY: OpRole.Optimize, - 'use_mkldnn': False}) + attrs={ + 'axis': -1, + OP_ROLE_KEY: OpRole.Optimize, + 'use_mkldnn': False + }) # cond_var = (step_var == 0) main_block.append_op( @@ -874,9 +887,8 @@ def _create_gm_cond(self, main_block): inputs={'X': current_step_var, 'Y': zero_var}, outputs={'Out': cond_var}, - attrs={OP_ROLE_KEY: OpRole.Optimize} - ) - paddle.static.Print(current_step_var, message="in FWBW last conditional") + attrs={OP_ROLE_KEY: OpRole.Optimize}) + # paddle.static.Print(current_step_var, message="in FWBW last conditional") return cond_var def _true_apply_gradient(self): @@ -896,7 +908,7 @@ def _true_apply_gradient(self): cur_block_idx = self._main_program.current_block_idx cur_block = self._main_program.current_block() self.cond_block = self._main_program.current_block() - + # cur_block's forward_block & backward_block is itself cur_block._set_forward_block_idx(cur_block_idx) @@ -909,11 +921,12 @@ def _true_apply_gradient(self): type='c_allreduce_sum', inputs={'X': merged_grad_var}, outputs={'Out': merged_grad_var}, - attrs={'ring_id': self.dp_ring_id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Optimize}) - print("after allreduce grad@gradientmerge ") - + attrs={ + 'ring_id': self.dp_ring_id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize + }) + # grad@gradientmerge / acc_step for grad, merged_grad in self._grad2merged_grad.items(): # grad /= k_steps @@ -928,7 +941,6 @@ def _true_apply_gradient(self): 'bias_after_scale': False, OP_ROLE_KEY: OpRole.Optimize }) - print("after allreduce grad@gradientmerge / acc_step") # re-create optimize ops already_moved_var_names = [] @@ -936,35 +948,37 @@ def _true_apply_gradient(self): new_op_desc = cur_block.desc.append_op() new_op_desc.copy_from(op_desc) - if op_desc.type in ["check_finite_and_unscale", "update_loss_scaling"]: - print("check @@@@@: ", op_desc.type) - print("input @@@@@: ", op_desc.output_arg_names) - for input_name in new_op_desc.input_arg_names(): if input_name in self._grad2merged_grad: - new_op_desc._rename_input(input_name, self._grad2merged_grad[input_name]) + new_op_desc._rename_input( + input_name, self._grad2merged_grad[input_name]) for output_name in new_op_desc.output_arg_names(): if output_name in self._grad2merged_grad: - new_op_desc._rename_output(output_name, self._grad2merged_grad[output_name]) - + new_op_desc._rename_output( + output_name, self._grad2merged_grad[output_name]) + # move non temp optimize vars from block0 to cond block - if output_name not in already_moved_var_names 
and output_name not in self._grad2merged_grad.keys(): + if output_name not in already_moved_var_names and output_name not in self._grad2merged_grad.keys( + ): var_ = self._main_program.global_block().var(output_name) if not var_.persistable: - print("remove: ", output_name) + logging.info( + "gradient merge move non persist var from block0: ", + output_name) # move name_ = var_.name shape_ = var_.shape type_ = var_.dtype - self._main_program.global_block()._remove_var(var_.name, sync=False) + self._main_program.global_block()._remove_var( + var_.name, sync=False) self.cond_block.create_var( - name=name_, + name=name_, shape=shape_, dtype=type_, persistable=False) already_moved_var_names.append(name_) - + self._main_program.global_block()._sync_with_cpp() cur_block._sync_with_cpp() @@ -981,12 +995,10 @@ def _true_apply_gradient(self): OP_ROLE_KEY: OpRole.Optimize }) - lr_var = main_block.var("@LR_DECAY_COUNTER@") - paddle.static.Print(lr_var, message="in OPTIMIZE last conditional") - + # lr_var = main_block.var("gradient_merge_current_step") + # paddle.static.Print(lr_var, message="in OPTIMIZE last conditional") def _sharding_gradient_merge(self, main_block): - """ copy all optimize ops in origin main block remove all optimize ops in origin main block @@ -1007,25 +1019,19 @@ def _sharding_gradient_merge(self, main_block): self.original_optimize_ops_desc.append(tmp_op_desc) main_block._remove_op(op_idx, sync=False) tmp_copy_block._sync_with_cpp() - self.original_optimize_ops_desc = list(reversed(self.original_optimize_ops_desc)) - print("original_optimize_ops_desc :") - for desc in self.original_optimize_ops_desc: - print(desc.type(), desc.input_arg_names()) + self.original_optimize_ops_desc = list( + reversed(self.original_optimize_ops_desc)) # back to block 0 self._main_program._rollback() - print("after first rollback") # create cond vars and ops at the end of block 0 cond = self._create_gm_cond(main_block) - print("after _create_gm_cond") - # create cond block cond_block = self._main_program._create_block() self._true_apply_gradient() - print("after _true_apply_gradient") - + # back to block 0 self._main_program._rollback() @@ -1044,15 +1050,3 @@ def _sharding_gradient_merge(self, main_block): 'sub_block': cond_block, 'is_scalar_condition': True, }) - - - - - - - - - - - - \ No newline at end of file From 726525c293b8541b6b43fc4bc61a1ce15b799d8a Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Mon, 29 Mar 2021 14:49:19 +0800 Subject: [PATCH 16/24] sharding: revise api --- .../framework/distributed_strategy.proto | 14 +-- .../meta_optimizers/sharding/fp16_helper.py | 17 ++- .../sharding/gradient_clip_helper.py | 28 ++--- .../fleet/meta_optimizers/sharding/utils.py | 57 ++++----- .../meta_optimizers/sharding_optimizer.py | 108 +++++++++--------- .../tests/unittests/dist_sharding_save.py | 5 +- .../unittests/fleet_meta_optimizer_base.py | 5 +- .../test_fleet_sharding_meta_optimizer.py | 83 +++++++------- 8 files changed, 160 insertions(+), 157 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 67923cf6cf441e..7703c60cddf918 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -29,14 +29,14 @@ message RecomputeConfig { } message ShardingConfig { - optional float broadcast_MB = 1 [ default = 32.0 ]; + optional float segment_broadcast_MB = 1 [ default = 32.0 ]; optional bool hybrid_dp = 2 [ default = false ]; - optional int32 sharding_group_size = 
3 [ default = 8 ]; - optional bool as_outer_parallelism = 4 [ default = false ]; - optional int32 inner_parallelism_size = 5 [ default = 8 ]; - optional string sharding_segment_strategy = 6 [ default = 'broadcast_size' ]; - repeated string sharding_segment_anchors = 7; - optional int32 gradient_merge_acc_step = 8 [ default = 1 ]; + optional int32 sharding_degree = 3 [ default = 8 ]; + optional int32 mp_degree = 4 [ default = 1 ]; + optional string sharding_segment_strategy = 5 + [ default = 'segment_broadcast_MB' ]; + repeated string segment_anchors = 6; + optional int32 gradient_merge_acc_step = 7 [ default = 1 ]; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py index a9f1327cb19a01..e946ed5fb3fbe6 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -118,8 +118,13 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): # the grad checking should take the all and only param in the current shard to_check_param = set(reversed_x_paramname) - should_check_param = set(shard.global_params).intersection(set([param for param, worker_idx in shard.global_param2device.items() if worker_idx == shard.worker_idx])) - assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(should_check_param - to_check_param, to_check_param - should_check_param) + should_check_param = set(shard.global_params).intersection( + set([param for param, worker_idx in shard.global_param2device.items() \ + if worker_idx == shard.worker_idx])) + assert to_check_param == should_check_param, "amp \ + check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format( + should_check_param - to_check_param, + to_check_param - should_check_param) if update_loss_scaling_op_idx == -1: return @@ -150,9 +155,11 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): type='c_allreduce_max', inputs={'X': inf_var_int32}, outputs={'Out': inf_var_int32}, - attrs={'ring_id': ring_id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Optimize}) + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize + }) # comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, # ring_id, [inf_var_int32]) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py index d0f8e13cdbf9cf..834ff74bb212b7 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -45,7 +45,7 @@ def prune_gradient_clip(self, block, shard): if shard.is_param(param_name) and \ not shard.has_param(param_name): deperate_op = True - elif shard.is_param(param_name) : + elif shard.is_param(param_name): reversed_x_paramname.append(param_name) if deperate_op: @@ -68,19 +68,12 @@ def prune_gradient_clip(self, block, shard): for input_name in op.desc.input_arg_names(): if input_name not in deperated_vars: reversed_inputs.append(input_name) - + op.desc.set_input("X", reversed_inputs) assert (len(op.desc.output_arg_names()) == 1) sum_res = op.desc.output_arg_names()[0] # this allreduce should not overlap with calc and should be scheduled in calc stream - # 
block._insert_op_without_sync( - # idx + 1, - # type='c_sync_comm_stream', - # inputs={'X': sum_res}, - # outputs={'Out': sum_res}, - # attrs={'ring_id': 0, - # OP_ROLE_KEY: OpRole.Optimize}) block._insert_op_without_sync( idx + 1, type='c_allreduce_sum', @@ -92,18 +85,17 @@ def prune_gradient_clip(self, block, shard): 'use_calc_stream': True, OP_ROLE_KEY: OpRole.Optimize, }) - # block._insert_op_without_sync( - # idx + 1, - # type='c_sync_calc_stream', - # inputs={'X': sum_res}, - # outputs={'Out': sum_res}, - # attrs={OP_ROLE_KEY: OpRole.Optimize}) # the grad sum here should take the all and only param in the current shard to_check_param = set(reversed_x_paramname) - should_check_param = set(shard.global_params).intersection(set([param for param, worker_idx in shard.global_param2device.items() if worker_idx == shard.worker_idx])) - assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(should_check_param - to_check_param, to_check_param - should_check_param) - + should_check_param = set(shard.global_params).intersection(set( + [param for param, worker_idx in shard.global_param2device.items() \ + if worker_idx == shard.worker_idx])) + assert to_check_param == should_check_param, "amp check_finite_and_unscale \ + checking miss [{}] and got unexpected [{}]".format( + should_check_param - to_check_param, + to_check_param - should_check_param) + for var_name in deperated_vars: block._remove_var(var_name, sync=False) block._sync_with_cpp() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index 0ee43e6157725d..26944d9798aa6b 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -103,7 +103,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): idx_gradient_clip_allreduce = -1 for idx, op in enumerate(block.ops): - if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum" : + if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum": if op.all_attrs()["use_calc_stream"] == False: ring_id = op.desc.attr("ring_id") var_name = op.desc.input_arg_names()[0] @@ -137,7 +137,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): var_name] == 0: dp_grads_status[var_name] = 1 - elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum" : + elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum": if op.all_attrs()["use_calc_stream"] == False: var_name = op.desc.input_arg_names()[0] ring_id = op.desc.attr("ring_id") @@ -192,8 +192,9 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): raise ValueError("There should be a sync_comm op " "after allreduce the Var: {}".format( input_name)) - raise ValueError("The reduce output grad [{}] should NOT be be used in Non-root rank.".format( - input_name)) + raise ValueError( + "The reduce output grad [{}] should NOT be be used in Non-root rank.". 
+ format(input_name)) if input_name in dp_grads_status: if dp_ring_id == -1: if dp_grads_status[input_name] != 3: @@ -205,7 +206,6 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): raise ValueError( "The grad in shard should be allreduce and sync" "twice before usage {}".format(input_name)) - for output_name in op.desc.output_arg_names(): if output_name in vars_status and \ @@ -338,6 +338,7 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): return + def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard): """ _add_allreduce_ops @@ -350,14 +351,15 @@ def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard): type='c_reduce_sum', inputs={'X': var}, outputs={'Out': var}, - attrs={'ring_id': ring_id, - 'root_id': root_id, - OP_ROLE_KEY: OpRole.Backward}) + attrs={ + 'ring_id': ring_id, + 'root_id': root_id, + OP_ROLE_KEY: OpRole.Backward + }) return - def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): """ _add_broadcast_ops @@ -461,7 +463,7 @@ def comm_analyse(main_program): count)) -def add_sync_comm(program, nccl_ids): +def add_sync_comm(program, sharding_ring_id): """ When clone a test prog by clone from the sharding main prog, part of the sync_comm op maybe be pruned by mistake, this function @@ -471,9 +473,7 @@ def add_sync_comm(program, nccl_ids): #NOTE (liangjianzhong): only support one comm stream by now, use more than one # comm streams will cause error. should be revise in future. - assert isinstance( - nccl_ids, list - ), "the second argument of this function should be a list of nccl_ids" + assert sharding_ring_id >= 0, "sharding_ring_id should larger than zero" block = program.global_block() not_sync_vars = set([]) for op in block.ops: @@ -484,15 +484,14 @@ def add_sync_comm(program, nccl_ids): for input_name in op.desc.input_arg_names(): not_sync_vars.remove(input_name) if not_sync_vars: - for nccl_id in nccl_ids: - block.append_op( - type='c_sync_comm_stream', - inputs={'X': list(not_sync_vars)}, - outputs={'Out': list(not_sync_vars)}, - attrs={ - 'ring_id': nccl_id, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward - }) + block.append_op( + type='c_sync_comm_stream', + inputs={'X': list(not_sync_vars)}, + outputs={'Out': list(not_sync_vars)}, + attrs={ + 'ring_id': sharding_ring_id, + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + }) return @@ -514,7 +513,7 @@ def is_opt_vars(var): if var.name.endswith(check): return True return False - + def is_gradient_merge_vars(var): # NOTE(liangjianzhong): to revise save/load logic in framework instead of write this naive rule @@ -525,7 +524,8 @@ def is_trainable(var): paddle.fluid.framework.Parameter) and var.trainable def sharding_predicate(var): - return is_trainable(var) or is_opt_vars(var) or is_gradient_merge_vars(var) + return is_trainable(var) or is_opt_vars(var) or is_gradient_merge_vars( + var) if int(os.environ.get('PADDLE_TRAINER_ID', 0)) == 0: paddle.fluid.io.save_persistables( @@ -540,16 +540,19 @@ def sharding_predicate(var): return + def get_grad_device(grad_name, shard): - assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(grad_name) + assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format( + grad_name) base_name = None # mind the traversal order possible_suffixes = ['.cast_fp16@GRAD', '@GRAD'] for suffix in possible_suffixes: - if suffix in grad_name : + if suffix in grad_name: base_name = re.sub(suffix, '', grad_name) break - assert base_name in shard.global_param2device, 
"[{}] should be a param variable.".format(base_name) + assert base_name in shard.global_param2device, "[{}] should be a param variable.".format( + base_name) return shard.global_param2device[base_name] diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 5864c060ed0778..91a37d897c137e 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -60,8 +60,7 @@ def __init__(self, optimizer): self._shard = Shard() # use sharding as outer parallelism (e.g. inner:Megatron & outer sharding) - self._as_outer_parallelism = False - self._inner_parallelism_size = None + self.mp_degree = 1 def _can_apply(self): if not self.role_maker._is_collective: @@ -76,7 +75,7 @@ def _disable_strategy(self, dist_strategy): def _enable_strategy(self, dist_strategy, context): dist_strategy.sharding = True - dist_strategy.sharding_configs = {"broadcast_MB": 32} + dist_strategy.sharding_configs = {"segment_broadcast_MB": 32} def minimize_impl(self, loss, @@ -90,22 +89,19 @@ def minimize_impl(self, self._nrings_dp = 1 self.hybrid_dp = self.user_defined_strategy.sharding_configs[ "hybrid_dp"] - self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[ - "as_outer_parallelism"] - self._inner_parallelism_size = int( - self.user_defined_strategy.sharding_configs[ - "inner_parallelism_size"]) + self.mp_degree = int(self.user_defined_strategy.sharding_configs[ + "mp_degree"]) self._sharding_segment_strategy = str( self.user_defined_strategy.sharding_configs[ "sharding_segment_strategy"]) - if self._sharding_segment_strategy == "broadcast_size": - self._broadcast_MB = int( - self.user_defined_strategy.sharding_configs["broadcast_MB"]) + if self._sharding_segment_strategy == "segment_broadcast_MB": + self._broadcast_MB = self.user_defined_strategy.sharding_configs[ + "segment_broadcast_MB"] assert self._broadcast_MB > 0, "segment size should larger than zero !" - elif self._sharding_segment_strategy == "anchors": + elif self._sharding_segment_strategy == "segment_anchors": self._sharding_segment_anchors = self.user_defined_strategy.sharding_configs[ - "sharding_segment_anchors"] + "segment_anchors"] assert len(self._sharding_segment_anchors ) > 0, "you should set the sharding segment anchors !" 
self._backward_remain_anchors = self._sharding_segment_anchors[:] @@ -149,8 +145,8 @@ def minimize_impl(self, # step4: insert reduce_sum for grad grad_scale_coeff = self.role_maker._worker_num() - if self._as_outer_parallelism: - grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size + if self.mp_degree: + grad_scale_coeff = grad_scale_coeff / self.mp_degree insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff) main_block._sync_with_cpp() @@ -200,7 +196,7 @@ def _set_up(self, params_grads): self.sharding_ring_id, False) # inner & outer model parallelism - if self._as_outer_parallelism: + if self.mp_degree > 1: self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False) @@ -217,7 +213,7 @@ def _set_up(self, params_grads): # step 2: split params self._params = set([x[0].name for x in params_grads]) self._shard.setup(params_grads, self.sharding_rank, - self.sharding_group_size) + self.sharding_degree) # step 3: get broadcast vars self._broadcast_vars = self._shard.find_broadcast_params( @@ -250,11 +246,11 @@ def _split_program(self, block): for op_idx in reversed(range(last_backward_op_idx)): op = block.ops[op_idx] assert (int(op.attr('op_role')) != int(OpRole.Optimize)) - if self._sharding_segment_strategy == "broadcast_size": + if self._sharding_segment_strategy == "segment_broadcast_MB": if segment._param_mem >= self._broadcast_MB: segment = self.collect_segment(segment, op_idx, block) - elif self._sharding_segment_strategy == "anchors": + elif self._sharding_segment_strategy == "segment_anchors": if int(op.attr('op_role')) == int(OpRole.Backward): for input_name in op.desc.input_arg_names(): @@ -351,7 +347,7 @@ def _split_program(self, block): segment._start_idx = 0 self._segments.insert(0, segment) - if self._sharding_segment_strategy == "anchors": + if self._sharding_segment_strategy == "segment_anchors": assert len( self._forward_remain_anchors) == 0, "remain anchors {}".format( self._forward_remain_anchors) @@ -389,7 +385,7 @@ def _prune_main_program(self, block): # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism # group. and each Data Parallelism group should have its own sync of FoundInfinite Model_Paramllelism_ring_id = self.sharding_ring_id - if self._as_outer_parallelism: + if self.mp_degree > 1: Model_Paramllelism_ring_id = self.mp_group_id FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, Model_Paramllelism_ring_id) @@ -653,91 +649,91 @@ def _prune_startup_program(self, block): def _init_comm(self): if self.hybrid_dp: - assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism" - self.sharding_group_size = self.user_defined_strategy.sharding_configs[ - "sharding_group_size"] + assert self.mp_degree <= 1, "hybrid dp is conflict when using sharding with megatron." 
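# Worked example of the group arithmetic below (hypothetical sizes, not
# part of this patch): with 8 trainers and sharding_degree = 4,
# global_rank 6 gets sharding_rank = 6 % 4 = 2 and dp_rank = 6 // 4 = 1;
# its sharding group is ranks {4, 5, 6, 7} (idx // 4 == 1) and its dp
# group is ranks {2, 6} (idx % 4 == 2), the peers that hold the same shard.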
+ self.sharding_degree = self.user_defined_strategy.sharding_configs[ + "sharding_degree"] self.sharding_ring_id = 0 - self.sharding_rank = self.global_rank % self.sharding_group_size + self.sharding_rank = self.global_rank % self.sharding_degree - self.dp_group_size = self.global_word_size // self.sharding_group_size - self.dp_rank = self.global_rank // self.sharding_group_size + self.dp_degree = self.global_word_size // self.sharding_degree + self.dp_rank = self.global_rank // self.sharding_degree self.dp_ring_id = self.sharding_rank + 1 self.sharding_group_endpoints = [ ep for idx, ep in enumerate(self.endpoints) - if (idx // self.sharding_group_size) == self.dp_rank + if (idx // self.sharding_degree) == self.dp_rank ] self.dp_group_endpoints = [ ep for idx, ep in enumerate(self.endpoints) - if (idx % self.sharding_group_size) == self.sharding_rank + if (idx % self.sharding_degree) == self.sharding_rank ] - assert self.global_word_size > self.sharding_group_size, \ - "global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) - assert self.global_word_size % self.sharding_group_size == 0, \ - "global_word_size: {} should be divisible to the sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) - assert self.dp_group_size * self.sharding_group_size == self.global_word_size, \ - "global_word_size: {} should be equal to the product of sharding_group_size: {} and dp_group_size: {}".format( + assert self.global_word_size > self.sharding_degree, \ + "global_word_size: {} should be larger than sharding_degree: {}".format(self.global_word_size, self.sharding_degree) + assert self.global_word_size % self.sharding_degree == 0, \ + "global_word_size: {} should be divisible to the sharding_degree: {}".format(self.global_word_size, self.sharding_degree) + assert self.dp_degree * self.sharding_degree == self.global_word_size, \ + "global_word_size: {} should be equal to the product of sharding_degree: {} and dp_degree: {}".format( self.global_word_size, - self.sharding_group_size, - self.dp_group_size) + self.sharding_degree, + self.dp_degree) # sharding parallelism is the only model parallelism in the current setting self.mp_group_id = self.sharding_ring_id self.mp_rank = self.sharding_rank - self.mp_group_size = self.sharding_group_size + self.mp_degree = self.sharding_degree self.mp_group_endpoints = self.sharding_group_endpoints[:] logging.info("Using Sharing&DP mode !") else: - if self._as_outer_parallelism: + if self.mp_degree > 1: self.sharding_ring_id = 1 - assert self.global_word_size > self._inner_parallelism_size, \ - "global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size) - assert self.global_word_size % self._inner_parallelism_size == 0, \ - "global_word_size: {} should be divisible to the inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size) - self.sharding_rank = self.global_rank // self._inner_parallelism_size - self.sharding_group_size = self.role_maker._worker_num( - ) // self._inner_parallelism_size - _offset = self.global_rank % self._inner_parallelism_size + assert self.global_word_size > self.mp_degree, \ + "global_word_size: {} should be larger than mp_degree: {}".format(self.global_word_size, self.mp_degree) + assert self.global_word_size % self.mp_degree == 0, \ + "global_word_size: {} should be divisible to the mp_degree: {}".format(self.global_word_size, self.mp_degree) + 
self.sharding_rank = self.global_rank // self.mp_degree + self.sharding_degree = self.role_maker._worker_num( + ) // self.mp_degree + _offset = self.global_rank % self.mp_degree self.sharding_group_endpoints = [ ep for idx, ep in enumerate(self.endpoints) - if idx % self._inner_parallelism_size == _offset + if idx % self.mp_degree == _offset ] # the current entire model parallelism group is the combination of innert & sharding parallelism self.mp_group_id = 2 self.mp_rank = self.global_rank - self.mp_group_size = self.role_maker._worker_num() + self.mp_degree = self.role_maker._worker_num() self.mp_group_endpoints = self.endpoints[:] logging.info("Using Sharing as Outer parallelism mode !") else: self.sharding_ring_id = 0 self.sharding_rank = self.global_rank - self.sharding_group_size = self.role_maker._worker_num() + self.sharding_degree = self.role_maker._worker_num() self.sharding_group_endpoints = self.endpoints # sharding parallelism is the only model parallelism in the current setting self.mp_group_id = self.sharding_ring_id self.mp_rank = self.sharding_rank - self.mp_group_size = self.sharding_group_size + self.mp_degree = self.sharding_degree self.mp_group_endpoints = self.sharding_group_endpoints[:] logging.info("Using Sharing alone mode !") self.dp_ring_id = -1 self.dp_rank = -1 - self.dp_group_size = None + self.dp_degree = None self.dp_group_endpoints = None logging.info("global word size: {}".format(self.global_word_size)) logging.info("global rank: {}".format(self.global_rank)) - logging.info("sharding group_size: {}".format(self.sharding_group_size)) + logging.info("sharding degree: {}".format(self.sharding_degree)) logging.info("sharding rank: {}".format(self.sharding_rank)) - logging.info("current model parallelism group_size: {}".format( - self.mp_group_size)) + logging.info("current model parallelism degree: {}".format( + self.mp_degree)) logging.info("current model parallelism rank: {}".format(self.mp_rank)) - logging.info("dp group size: {}".format(self.dp_group_size)) + logging.info("dp group size: {}".format(self.dp_degree)) logging.info("dp rank: {}".format(self.dp_rank)) logging.info("current endpoint: {}".format(self.current_endpoint)) logging.info("global word endpoints: {}".format(self.endpoints)) diff --git a/python/paddle/fluid/tests/unittests/dist_sharding_save.py b/python/paddle/fluid/tests/unittests/dist_sharding_save.py index 22c930bf8948aa..d686c507e3b0c8 100755 --- a/python/paddle/fluid/tests/unittests/dist_sharding_save.py +++ b/python/paddle/fluid/tests/unittests/dist_sharding_save.py @@ -59,7 +59,10 @@ def runtime_main(): strategy = paddle.distributed.fleet.DistributedStrategy() strategy.sharding = True - strategy.sharding_configs = {"fuse_broadcast_MB": 0.2} + strategy.sharding_configs = { + "sharding_segment_strategy": "segment_broadcast_MB", + "segment_broadcast_MB": 0.2 + } optimizer = paddle.fluid.optimizer.Momentum( learning_rate=0.01, momentum=0.9) diff --git a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py index 1c74a11cc4d2e6..fb2aaef7b3a668 100755 --- a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py +++ b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py @@ -146,7 +146,10 @@ def set_strategy(self, strategy, name): strategy.gradient_merge_configs = {"k_steps": 2, "avg": True} elif name == "sharding": strategy.sharding = True - strategy.sharding_configs = {"fuse_broadcast_MB": 0.2} + strategy.sharding_configs = { + 
"sharding_segment_strategy": "segment_broadcast_MB", + "segment_broadcast_MB": 0.2 + } elif name == "recompute-offload": strategy.recompute = True strategy.recompute_configs = { diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index 5da7e627f8707d..fc5de320db2872 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -45,6 +45,7 @@ def test_sharding_optimizer(self): "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" ])) + self.assertEqual(ops, [ 'fill_constant', 'fill_constant', 'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', @@ -55,9 +56,9 @@ def test_sharding_optimizer(self): 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_sync_comm_stream', 'momentum', 'momentum', 'momentum' + 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', + 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'momentum', + 'momentum', 'momentum' ]) def test_sharding_amp_optimizer(self): @@ -82,6 +83,7 @@ def test_sharding_amp_optimizer(self): "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0", "loss_scaling_0", "num_bad_steps_0", "num_good_steps_0" ])) + self.assertEqual(ops, [ 'cast', 'cast', 'cast', 'fill_constant', 'fill_constant', 'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', @@ -94,11 +96,10 @@ def test_sharding_amp_optimizer(self): 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', - 'c_sync_calc_stream', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_allreduce_sum', 'c_sync_comm_stream', 'cast', 'cast', 'cast', - 'check_finite_and_unscale', 'cast', 'c_sync_calc_stream', - 'c_allreduce_max', 'c_sync_comm_stream', 'cast', + 'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum', + 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', + 'c_sync_comm_stream', 'cast', 'cast', 'cast', + 'check_finite_and_unscale', 'cast', 'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum', 'momentum', 'momentum' ]) @@ -124,6 +125,7 @@ def test_sharding_recompute_optimizer(self): "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" ])) + self.assertEqual(ops, [ 'fill_constant', 'fill_constant', 'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', @@ -134,10 +136,9 @@ def test_sharding_recompute_optimizer(self): 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad', - 'mul_grad', 'c_sync_calc_stream', 'c_allreduce_sum', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_sync_comm_stream', - 'momentum', 'momentum', 'momentum' + 'mul_grad', 'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum', + 'c_reduce_sum', 'c_reduce_sum', 
'c_reduce_sum', 'c_reduce_sum', + 'c_sync_comm_stream', 'momentum', 'momentum', 'momentum' ]) def test_sharding_amp_recompute_optimizer(self): @@ -167,29 +168,27 @@ def test_sharding_amp_recompute_optimizer(self): "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0", "loss_scaling_0", "num_bad_steps_0", "num_good_steps_0" ])) - self.assertEqual(ops, [ - 'cast', 'cast', 'cast', 'fill_constant', 'fill_constant', + 'cast', 'cast', 'cast', 'cast', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', - 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh', - 'cast', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', - 'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2', - 'mean', 'elementwise_mul', 'fill_constant', 'scale', - 'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast', - 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', - 'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh_grad', - 'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', - 'cast', 'elementwise_add', 'cast', 'tanh_grad', 'cast', + 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', + 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add', + 'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul', + 'fill_constant', 'scale', 'elementwise_mul_grad', 'mean_grad', + 'cross_entropy_grad2', 'cast', 'softmax_grad', + 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', + 'elementwise_add', 'cast', 'tanh_grad', 'cast', + 'elementwise_add_grad', 'mul_grad', 'cast', 'mul', + 'elementwise_add', 'cast', 'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_sync_comm_stream', 'cast', 'cast', 'cast', - 'check_finite_and_unscale', 'cast', 'c_sync_calc_stream', - 'c_allreduce_max', 'c_sync_comm_stream', 'cast', - 'update_loss_scaling', 'momentum', 'momentum', 'momentum' + 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', + 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'cast', + 'cast', 'cast', 'check_finite_and_unscale', 'cast', + 'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum', + 'momentum', 'momentum' ]) def test_sharding_weight_decay(self): @@ -227,10 +226,10 @@ def test_sharding_weight_decay(self): 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_sync_comm_stream', 'scale', 'sum', 'scale', 'sum', 'scale', - 'sum', 'momentum', 'momentum', 'momentum' + 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', + 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'scale', + 'sum', 'scale', 'sum', 'scale', 'sum', 'momentum', 'momentum', + 'momentum' ]) def test_sharding_gradient_clip(self): @@ -253,6 +252,7 @@ def test_sharding_gradient_clip(self): "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" ])) + self.assertEqual(ops, [ 'fill_constant', 'fill_constant', 
'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', @@ -263,14 +263,12 @@ def test_sharding_gradient_clip(self): 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', - 'c_sync_comm_stream', 'square', 'reduce_sum', 'square', - 'reduce_sum', 'square', 'reduce_sum', 'sum', 'c_sync_calc_stream', - 'c_allreduce_sum', 'c_sync_comm_stream', 'sqrt', 'fill_constant', - 'elementwise_max', 'elementwise_div', 'elementwise_mul', - 'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum', - 'momentum' + 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', + 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'square', + 'reduce_sum', 'square', 'reduce_sum', 'square', 'reduce_sum', 'sum', + 'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max', + 'elementwise_div', 'elementwise_mul', 'elementwise_mul', + 'elementwise_mul', 'momentum', 'momentum', 'momentum' ]) def test_sharding_clone_for_test(self): @@ -281,7 +279,8 @@ def test_sharding_clone_for_test(self): self.optimizer(avg_cost, strategy, train_prog, startup_prog) sharding.utils.comm_analyse(train_prog) test_prog = train_prog.clone(for_test=True) - sharding.utils.add_sync_comm(test_prog, strategy) + # assume sharding_ring_id = 0 + sharding.utils.add_sync_comm(test_prog, 0) ops = [op.type for op in test_prog.global_block().ops] self.assertEqual(ops, [ From cb788cfeb62ec10d82c688ffbc4b666b75ba884e Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Mon, 29 Mar 2021 17:40:31 +0800 Subject: [PATCH 17/24] sharding: remove debug log --- .../meta_optimizers/sharding_optimizer.py | 49 ++++++------------- 1 file changed, 16 insertions(+), 33 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 91a37d897c137e..3608e88b5de2bd 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -58,6 +58,7 @@ def __init__(self, optimizer): # reduced grads to param name self._reduced_grads_to_param = {} self._shard = Shard() + self._verbose = False # use sharding as outer parallelism (e.g. 
inner:Megatron & outer sharding) self.mp_degree = 1 @@ -84,7 +85,6 @@ def minimize_impl(self, no_grad_set=None): # TODO: (JZ-LIANG) support multiple comm in future # self._nrings = self.user_defined_strategy.nccl_comm_num - logging.info("########### into sharding minimize: wait at end") self._nrings_sharding = 1 self._nrings_dp = 1 self.hybrid_dp = self.user_defined_strategy.sharding_configs[ @@ -120,7 +120,6 @@ def minimize_impl(self, "self.inner_opt of ShardingOptimizer should not be None.") optimize_ops, params_grads = self.inner_opt.minimize( loss, startup_program, parameter_list, no_grad_set) - logging.info("########### after inner opt") if startup_program is None: startup_program = default_startup_program() @@ -131,17 +130,14 @@ def minimize_impl(self, # step1: set_up self._set_up(params_grads) - logging.info("########### after inner _set_up") # step2: split_program self._split_program(main_block) - logging.info("########### after inner _split_program") # step3: add broadcast and reduce ops self._add_broadcast_allreduce(main_block) main_block._sync_with_cpp() startup_block._sync_with_cpp() - logging.info("########### after inner _add_broadcast_allreduce") # step4: insert reduce_sum for grad grad_scale_coeff = self.role_maker._worker_num() @@ -155,7 +151,6 @@ def minimize_impl(self, self._prune_startup_program(startup_block) if self.hybrid_dp: self._initialization_broadcast(startup_program) - logging.info("########### after inner prune") with open("main_before_gm_%d" % self.role_maker._worker_index(), 'w') as f: @@ -173,9 +168,7 @@ def minimize_impl(self, # check_broadcast(main_block) # check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, # self.dp_ring_id) - logging.info("########### after inner check") self._wait() - logging.info("########### after inner wait") return optimize_ops, params_grads @@ -263,11 +256,6 @@ def _split_program(self, block): ".cast_fp16@GRAD")] if input_name in self._backward_remain_anchors: - logging.info("backward segment:") - logging.info("op [{}] input [{}] output [{}]". - format(op.desc.type(), - op.desc.input_arg_names(), - op.desc.output_arg_names())) segment = self.collect_segment(segment, op_idx, block) assert input_name not in self._forward_remain_anchors, "segment anchor [{}] met twice !".format( @@ -277,11 +265,6 @@ def _split_program(self, block): elif int(op.attr('op_role')) == int(OpRole.Forward): for output_name in op.desc.output_arg_names(): if output_name in self._forward_remain_anchors: - logging.info("forward segment:") - logging.info("op [{}] input [{}] output [{}]". 
- format(op.desc.type(), - op.desc.input_arg_names(), - op.desc.output_arg_names())) segment = self.collect_segment(segment, op_idx, block) self._forward_remain_anchors.remove(output_name) @@ -355,18 +338,21 @@ def _split_program(self, block): self._backward_remain_anchors) == 0, "remain anchors {}".format( self._backward_remain_anchors) - for varname in sorted( - var2broadcast_time, key=var2broadcast_time.get, reverse=True): - logging.info("Sharding broadcast: [{}] times [{}]".format( - var2broadcast_time[varname], varname)) - for idx_ in range(len(self._segments)): - logging.info("segment [{}] :".format(idx_)) - logging.info("start op: [{}] [{}]".format(block.ops[self._segments[ - idx_]._start_idx].desc.type(), block.ops[self._segments[ - idx_]._start_idx].desc.input_arg_names())) - logging.info("end op: [{}] [{}]".format(block.ops[self._segments[ - idx_]._end_idx].desc.type(), block.ops[self._segments[ - idx_]._end_idx].desc.input_arg_names())) + if self._verbose: + for varname in sorted( + var2broadcast_time, key=var2broadcast_time.get, + reverse=True): + logging.info("Sharding broadcast: [{}] times [{}]".format( + var2broadcast_time[varname], varname)) + for idx_ in range(len(self._segments)): + logging.info("segment [{}] :".format(idx_)) + logging.info("start op: [{}] [{}]".format(block.ops[ + self._segments[idx_]._start_idx].desc.type(), block.ops[ + self._segments[idx_]._start_idx].desc.input_arg_names( + ))) + logging.info("end op: [{}] [{}]".format(block.ops[ + self._segments[idx_]._end_idx].desc.type(), block.ops[ + self._segments[idx_]._end_idx].desc.input_arg_names())) return def _prune_main_program(self, block): @@ -959,9 +945,6 @@ def _true_apply_gradient(self): ): var_ = self._main_program.global_block().var(output_name) if not var_.persistable: - logging.info( - "gradient merge move non persist var from block0: ", - output_name) # move name_ = var_.name shape_ = var_.shape From cf5b1c91fce947febc23e296520da30689313f98 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Mon, 29 Mar 2021 22:52:08 +0800 Subject: [PATCH 18/24] sharding: add sync in startup prog, uniform parallelism switch --- .../sharding/gradient_clip_helper.py | 17 +- .../fleet/meta_optimizers/sharding/utils.py | 22 ++ .../meta_optimizers/sharding_optimizer.py | 323 +++++++++++------- .../tests/unittests/dist_sharding_save.py | 3 +- .../unittests/fleet_meta_optimizer_base.py | 3 +- .../test_fleet_sharding_meta_optimizer.py | 4 +- 6 files changed, 245 insertions(+), 127 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py index 834ff74bb212b7..5082bc33830198 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -23,7 +23,7 @@ def _is_gradient_clip_op(self, op): return op.desc.has_attr("op_namescope") \ and op.desc.attr("op_namescope").startswith("/gradient_clip") - def prune_gradient_clip(self, block, shard): + def prune_gradient_clip(self, block, shard, pure_dp_degree=1): """ prune gradient_clip related ops for params that not belong to cur shard prune: square, reduce_sum, elementwise_mul @@ -86,6 +86,21 @@ def prune_gradient_clip(self, block, shard): OP_ROLE_KEY: OpRole.Optimize, }) + # global norm should only be sum within each model parallelism word size when use global group + if pure_dp_degree > 1: + block._insert_op_without_sync( + idx + 2, + 
type='scale',
+                inputs={'X': sum_res},
+                outputs={'Out': sum_res},
+                attrs={
+                    'scale': 1.0 / float(pure_dp_degree),
+                    'op_namescope': "/gradient_clip_model_parallelism",
+                    'bias': 0.0,
+                    'bias_after_scale': False,
+                    OP_ROLE_KEY: OpRole.Optimize
+                })
+
         # the grad sum here should take the all and only param in the current shard
         to_check_param = set(reversed_x_paramname)
         should_check_param = set(shard.global_params).intersection(set(
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
index 26944d9798aa6b..bf1e3186c44164 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -556,3 +556,25 @@ def get_grad_device(grad_name, shard):
         base_name)
 
     return shard.global_param2device[base_name]
+
+
+def append_naive_sync(block, sync_var, ring_id):
+    # NOTE (JZ-LIANG) update this to use barrier sync for more elegant logic
+    # sync within global
+    block.append_op(
+        type="fill_constant",
+        outputs={"Out": sync_var},
+        attrs={
+            "shape": sync_var.shape,
+            "dtype": sync_var.dtype,
+            "value": int(1),
+        })
+    block.append_op(
+        type='c_allreduce_sum',
+        inputs={'X': sync_var},
+        outputs={'Out': sync_var},
+        attrs={
+            'ring_id': ring_id,
+            'use_calc_stream': True,
+            OP_ROLE_KEY: OpRole.Forward
+        })
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 3608e88b5de2bd..f7691f15e5e548 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -87,14 +87,31 @@ def minimize_impl(self,
         # self._nrings = self.user_defined_strategy.nccl_comm_num
         self._nrings_sharding = 1
         self._nrings_dp = 1
-        self.hybrid_dp = self.user_defined_strategy.sharding_configs[
-            "hybrid_dp"]
+
+        # parallelism
+        self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
+            "sharding_degree"])
+        assert self.sharding_degree > 1, "sharding degree must be larger than one"
         self.mp_degree = int(self.user_defined_strategy.sharding_configs[
             "mp_degree"])
+        self.hybrid_dp = self.user_defined_strategy.sharding_configs[
+            "hybrid_dp"]
+
+        self.pp_degree = 1
+
+        # dp here is the pure dp as the outermost parallelism
+        self.dp_degree = int(self.role_maker._worker_num() // self.mp_degree //
+                             self.sharding_degree)
+        assert self.role_maker._worker_num(
+        ) == self.dp_degree * self.mp_degree * self.sharding_degree * self.pp_degree
+        if self.hybrid_dp:
+            assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
+                self.dp_degree)
+
+        # segment
         self._sharding_segment_strategy = str(
             self.user_defined_strategy.sharding_configs[
                 "sharding_segment_strategy"])
-
         if self._sharding_segment_strategy == "segment_broadcast_MB":
             self._broadcast_MB = self.user_defined_strategy.sharding_configs[
                 "segment_broadcast_MB"]
@@ -110,6 +127,8 @@ def minimize_impl(self,
             raise NotImplementedError(
                 "the sharding segment strategy [{}] is not implemented".format(
                     str(self._sharding_segment_strategy)))
+
+        # gradient merge
         self._gradient_merge_acc_step = int(
             self.user_defined_strategy.sharding_configs[
                 "gradient_merge_acc_step"])
@@ -128,8 +147,11 @@ def minimize_impl(self,
         self._main_program = main_block.program
         self._startup_program = startup_program
 
-        # step1: set_up
-        self._set_up(params_grads)
+        # step0: _init_comm
+        self._init_comm()
+
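# A concrete instance of the degree bookkeeping above (hypothetical
# sizes): with 16 trainers, mp_degree = 2 and sharding_degree = 4,
# dp_degree = 16 // 2 // 4 = 2 and the worker-count invariant
# dp * mp * sharding * pp = 2 * 2 * 4 * 1 == 16 holds, so the
# assertion on _worker_num() passes.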
+        # step1: _build_shard
+        self._build_shard(params_grads)
 
         # step2: split_program
         self._split_program(main_block)
@@ -139,11 +161,12 @@ def minimize_impl(self,
         main_block._sync_with_cpp()
         startup_block._sync_with_cpp()
 
-        # step4: insert reduce_sum for grad
-        grad_scale_coeff = self.role_maker._worker_num()
-        if self.mp_degree:
-            grad_scale_coeff = grad_scale_coeff / self.mp_degree
-        insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff)
+        # step4: scale the loss by the degree of data parallelism
+        # sharding is also a scenario of dp
+        scale_ = self.dp_degree * self.sharding_degree
+        if scale_ > 1:
+            insert_scale_loss_grad_ops(main_block, scale=1.0 / scale_)
+
         main_block._sync_with_cpp()
 
         # step5: remove unneeded ops and vars from block
@@ -152,18 +175,10 @@ def minimize_impl(self,
         if self.hybrid_dp:
             self._initialization_broadcast(startup_program)
 
-        with open("main_before_gm_%d" % self.role_maker._worker_index(),
-                  'w') as f:
-            f.writelines(str(self._main_program))
-
         # step6: optional gradient merge
         if self._gradient_merge_acc_step > 1:
             self._sharding_gradient_merge(main_block)
 
-        with open("main_after_gm_%d" % self.role_maker._worker_index(),
-                  'w') as f:
-            f.writelines(str(self._main_program))
-
         # # check op dependecy
         # check_broadcast(main_block)
         # check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
@@ -172,37 +187,52 @@ def minimize_impl(self,
 
         return optimize_ops, params_grads
 
-    def _set_up(self, params_grads):
-        # step 1: initialize nccl
-        self.global_word_size = self.role_maker._worker_num()
-        self.global_rank = self.role_maker._worker_index()
-        self.endpoints = self.role_maker._get_trainer_endpoints()
-        self.current_endpoint = self.endpoints[self.global_rank]
-        self._collective_helper = CollectiveHelper(self.role_maker,
-                                                   self._nrings_sharding)
+    def _init_comm(self):
         # config sharding & dp groups
-        self._init_comm()
-        # sharding
+        self._build_group()
+
+        startup_block = self._startup_program.global_block()
+        self.startup_prog_sync_var = startup_block.create_var(
+            name="startup_prog_sync_var",
+            shape=[1],
+            dtype=core.VarDesc.VarType.INT32,
+            persistable=False)
+
+        # global
         self._collective_helper._init_communicator(
-            self._startup_program, self.current_endpoint,
-            self.sharding_group_endpoints, self.sharding_rank,
-            self.sharding_ring_id, False)
+            self._startup_program, self.current_endpoint, self.global_endpoints,
+            self.global_rank, self.global_ring_id, False)
+        append_naive_sync(startup_block, self.startup_prog_sync_var,
+                          self.global_ring_id)
 
-        # inner & outer model parallelism
+        # mp
         if self.mp_degree > 1:
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
-                self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False)
+                self.mp_group_endpoints, self.mp_rank, self.mp_ring_id, False)
+            append_naive_sync(startup_block, self.startup_prog_sync_var,
+                              self.global_ring_id)
+
+        # sharding
+        if self.sharding_degree > 1:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.sharding_group_endpoints, self.sharding_rank,
+                self.sharding_ring_id, False)
+            append_naive_sync(startup_block, self.startup_prog_sync_var,
+                              self.global_ring_id)
 
         # dp
-        if self.hybrid_dp:
+        if self.dp_degree > 1:
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
                 self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, False)
+            append_naive_sync(startup_block, self.startup_prog_sync_var,
+                              self.global_ring_id)
 
-        startup_block = self._startup_program.global_block()
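# What each append_naive_sync call above expands to (see utils.py earlier
# in this patch): every rank fills a one-element INT32 variable with 1 and
# allreduce-sums it on the given ring with use_calc_stream=True. No rank
# can leave the allreduce before all its peers arrive, so the per-group
# communicator initializations cannot race each other -- a poor man's
# barrier built from existing ops.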
startup_block._sync_with_cpp() + def _build_shard(self, params_grads): # step 2: split params self._params = set([x[0].name for x in params_grads]) self._shard.setup(params_grads, self.sharding_rank, @@ -269,7 +299,7 @@ def _split_program(self, block): block) self._forward_remain_anchors.remove(output_name) -# find broadcast vars + # find broadcast vars for input_name in op.desc.input_arg_names(): if input_name not in self._broadcast_vars: continue @@ -372,11 +402,20 @@ def _prune_main_program(self, block): # group. and each Data Parallelism group should have its own sync of FoundInfinite Model_Paramllelism_ring_id = self.sharding_ring_id if self.mp_degree > 1: - Model_Paramllelism_ring_id = self.mp_group_id + Model_Paramllelism_ring_id = self.mp_ring_id + # amp could use global group for sync FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, - Model_Paramllelism_ring_id) - gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id) - gradientclip_helper.prune_gradient_clip(block, self._shard) + self.global_ring_id) + # clipbyglobalnorm should only use the Model paramllelism group (mp-sharding-pp) + if self.mp_degree * self.pp_degree == 1: + # separate the sharding-hybrid senario to keep the accuracy + gradientclip_helper = GradientClipHelper(self.sharding_ring_id) + gradientclip_helper.prune_gradient_clip( + block, self._shard, pure_dp_degree=1) + else: + gradientclip_helper = GradientClipHelper(self.global_ring_id) + gradientclip_helper.prune_gradient_clip( + block, self._shard, pure_dp_degree=self.dp_degree) # build prog deps reduced_grads = [] @@ -632,102 +671,138 @@ def _prune_startup_program(self, block): block._remove_var(var_name, sync=False) block._sync_with_cpp() - def _init_comm(self): - - if self.hybrid_dp: - assert self.mp_degree <= 1, "hybrid dp is conflict when using sharding with megatron." 
- self.sharding_degree = self.user_defined_strategy.sharding_configs[ - "sharding_degree"] - self.sharding_ring_id = 0 - self.sharding_rank = self.global_rank % self.sharding_degree - - self.dp_degree = self.global_word_size // self.sharding_degree - self.dp_rank = self.global_rank // self.sharding_degree - self.dp_ring_id = self.sharding_rank + 1 - - self.sharding_group_endpoints = [ - ep for idx, ep in enumerate(self.endpoints) - if (idx // self.sharding_degree) == self.dp_rank - ] - self.dp_group_endpoints = [ - ep for idx, ep in enumerate(self.endpoints) - if (idx % self.sharding_degree) == self.sharding_rank + def _build_group(self): + """ + pre-assign ring ids + mp: 0 + sharding: 1 + pure-dp: 2 + global: 3 + pp: >= 20 + if one parallelism is not enable: -1 + and only support parallelism hierarchy: mp --> sharding --> pp --> dp + """ + # step 1: initialize nccl + self.global_word_size = self.role_maker._worker_num() + self.global_rank = self.role_maker._worker_index() + self.global_endpoints = self.role_maker._get_trainer_endpoints() + self.current_endpoint = self.global_endpoints[self.global_rank] + self._collective_helper = CollectiveHelper( + self.role_maker, nrings=self._nrings_sharding) + assert self.global_word_size % self.mp_degree == 0, \ + "global_word_size: {} should be divisible to the mp_degree: {}".format(self.global_word_size, self.mp_degree) + assert self.global_word_size % self.sharding_degree == 0, \ + "global_word_size: {} should be divisible to the sharding_degree: {}".format(self.global_word_size, self.sharding_degree) + assert self.global_word_size % self.pp_degree == 0, \ + "global_word_size: {} should be divisible to the pp_degree: {}".format(self.global_word_size, self.pp_degree) + assert self.global_word_size % self.dp_degree == 0, \ + "global_word_size: {} should be divisible to the dp_degree: {}".format(self.global_word_size, self.dp_degree) + + # mp group + if self.mp_degree > 1: + self.mp_ring_id = 0 + self.mp_rank = self.global_rank % self.mp_degree + self.mp_group_id = self.global_rank // self.mp_degree + self.mp_group_endpoints = [ + ep for idx, ep in enumerate(self.global_endpoints) + if idx // self.mp_degree == self.mp_group_id ] - assert self.global_word_size > self.sharding_degree, \ - "global_word_size: {} should be larger than sharding_degree: {}".format(self.global_word_size, self.sharding_degree) - assert self.global_word_size % self.sharding_degree == 0, \ - "global_word_size: {} should be divisible to the sharding_degree: {}".format(self.global_word_size, self.sharding_degree) - assert self.dp_degree * self.sharding_degree == self.global_word_size, \ - "global_word_size: {} should be equal to the product of sharding_degree: {} and dp_degree: {}".format( - self.global_word_size, - self.sharding_degree, - self.dp_degree) - - # sharding parallelism is the only model parallelism in the current setting - self.mp_group_id = self.sharding_ring_id - self.mp_rank = self.sharding_rank - self.mp_degree = self.sharding_degree - self.mp_group_endpoints = self.sharding_group_endpoints[:] - - logging.info("Using Sharing&DP mode !") + assert self.current_endpoint in self.mp_group_endpoints + assert len( + self.mp_group_endpoints + ) == self.mp_degree, "num of mp worker in group is [{}], but mp group size is [{}]".format( + len(self.mp_group_endpoints), self.mp_degree) else: + self.mp_degree = 1 + self.mp_ring_id = -1 + self.mp_rank = -1 + self.mp_group_id = -1 + self.mp_group_endpoints = [] + + # sharding + if self.sharding_degree > 1: + self.sharding_ring_id 
= 1
+            self.sharding_rank = (self.global_rank //
+                                  self.mp_degree) % self.sharding_degree
+            self.sharding_group_id = self.global_rank // (self.mp_degree *
+                                                          self.sharding_degree)
+            # mp + sharding + ...
             if self.mp_degree > 1:
                 self.sharding_group_endpoints = [
-                    ep for idx, ep in enumerate(self.endpoints)
-                    if idx % self.mp_degree == _offset
+                    ep for idx, ep in enumerate(self.global_endpoints)
+                    if (idx // (self.mp_degree * self.sharding_degree)) == self.
+                    sharding_group_id and idx % self.mp_degree == self.mp_rank
                 ]
-
-                # the current entire model parallelism group is the combination of innert & sharding parallelism
-                self.mp_group_id = 2
-                self.mp_rank = self.global_rank
-                self.mp_degree = self.role_maker._worker_num()
-                self.mp_group_endpoints = self.endpoints[:]
-                logging.info("Using Sharing as Outer parallelism mode !")
-
+            # sharding + ...
             else:
-                self.sharding_ring_id = 0
-                self.sharding_rank = self.global_rank
-                self.sharding_degree = self.role_maker._worker_num()
-                self.sharding_group_endpoints = self.endpoints
-
-                # sharding parallelism is the only model parallelism in the current setting
-                self.mp_group_id = self.sharding_ring_id
-                self.mp_rank = self.sharding_rank
-                self.mp_degree = self.sharding_degree
-                self.mp_group_endpoints = self.sharding_group_endpoints[:]
-
-                logging.info("Using Sharing alone mode !")
-
+                self.sharding_group_endpoints = [
+                    ep for idx, ep in enumerate(self.global_endpoints)
+                    if (idx // (self.mp_degree * self.sharding_degree)
+                        ) == self.sharding_group_id
+                ]
+            assert self.current_endpoint in self.sharding_group_endpoints
+        else:
+            self.sharding_degree = 1
+            self.sharding_ring_id = -1
+            self.sharding_rank = -1
+            self.sharding_group_id = -1
+            self.sharding_group_endpoints = []
+
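# Worked example of the sharding-group selection above (hypothetical
# sizes): 8 trainers, mp_degree = 2, sharding_degree = 2. global_rank 5
# has mp_rank = 1 and sharding_group_id = 5 // (2 * 2) = 1, so its
# sharding group keeps the indices with idx // 4 == 1 and idx % 2 == 1,
# i.e. endpoints[5] and endpoints[7].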
+        # outer-pure-dp group
+        # NOTE (JZ-LIANG) support outer-pure-dp to scale the throughput in 3D parallelism
+        # e.g. mp-sharding-pp-dp
+        # sharding-hybrid-dp as one scenario of outer-pure-dp
+        assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
+            self.mp_degree, self.sharding_degree, self.pp_degree,
+            self.dp_degree, self.global_word_size)
+        if self.dp_degree > 1:
+            self.dp_ring_id = 2
+            self.dp_rank = self.global_rank // (self.sharding_degree *
+                                                self.mp_degree * self.pp_degree)
+            dp_first_rank_idx = self.global_rank % (
+                self.sharding_degree * self.mp_degree * self.pp_degree)
+            dp_offset = (self.sharding_degree * self.mp_degree * self.pp_degree)
+            self.dp_group_endpoints = []
+            for i in range(self.dp_degree):
+                self.dp_group_endpoints.append(self.global_endpoints[
+                    dp_first_rank_idx + dp_offset * i])
+            assert self.current_endpoint in self.dp_group_endpoints
+            logging.info("Hybrid DP mode turned on !")
+        else:
             self.dp_ring_id = -1
             self.dp_rank = -1
-            self.dp_degree = None
-            self.dp_group_endpoints = None
+            self.dp_group_endpoints = []
+
+        # global group
+        self.global_ring_id = 3
 
         logging.info("global word size: {}".format(self.global_word_size))
         logging.info("global rank: {}".format(self.global_rank))
-        logging.info("sharding degree: {}".format(self.sharding_degree))
+        logging.info("global endpoints: {}".format(self.global_endpoints))
+        logging.info("global ring id: {}".format(self.global_ring_id))
+        logging.info("#####" * 6)
+
+        logging.info("mp group size: {}".format(self.mp_degree))
+        logging.info("mp rank: {}".format(self.mp_rank))
+        logging.info("mp group id: {}".format(self.mp_group_id))
+        logging.info("mp group endpoints: {}".format(self.mp_group_endpoints))
+        logging.info("mp ring id: {}".format(self.mp_ring_id))
+        logging.info("#####" * 6)
+
+        logging.info("sharding group size: {}".format(self.sharding_degree))
         logging.info("sharding rank: {}".format(self.sharding_rank))
-        logging.info("current model parallelism degree: {}".format(
-            self.mp_degree))
-        logging.info("current model parallelism rank: {}".format(self.mp_rank))
-        logging.info("dp group size: {}".format(self.dp_degree))
-        logging.info("dp rank: {}".format(self.dp_rank))
-        logging.info("current endpoint: {}".format(self.current_endpoint))
-        logging.info("global word endpoints: {}".format(self.endpoints))
+        logging.info("sharding group id: {}".format(self.sharding_group_id))
         logging.info("sharding group endpoints: {}".format(
             self.sharding_group_endpoints))
-        logging.info("current model parallelism group endpoints: {}".format(
-            self.mp_group_endpoints))
-        logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
+        logging.info("sharding ring id: {}".format(self.sharding_ring_id))
+        logging.info("#####" * 6)
+
+        logging.info("outer pure dp group size: {}".format(self.dp_degree))
+        logging.info("outer pure dp rank: {}".format(self.dp_rank))
+        logging.info("outer pure dp group endpoints: {}".format(
+            self.dp_group_endpoints))
+        logging.info("outer pure dp ring id: {}".format(self.dp_ring_id))
+        logging.info("#####" * 6)
 
         return
 
@@ -756,6 +831,10 @@ def _initialization_broadcast(self, startup_prog):
                 attrs={'ring_id': self.dp_ring_id,
                        OP_ROLE_KEY: OpRole.Forward})
 
+        # sync within global group
+        append_naive_sync(block, self.startup_prog_sync_var,
+                          self.global_ring_id)
+
     # sharding gradient merge
     def create_persistable_gradients_and_insert_merge_ops(
             self, main_block, startup_block, insert_idx, grad_names, shard):
diff --git 
a/python/paddle/fluid/tests/unittests/dist_sharding_save.py b/python/paddle/fluid/tests/unittests/dist_sharding_save.py index d686c507e3b0c8..676b15c0d93e76 100755 --- a/python/paddle/fluid/tests/unittests/dist_sharding_save.py +++ b/python/paddle/fluid/tests/unittests/dist_sharding_save.py @@ -61,7 +61,8 @@ def runtime_main(): strategy.sharding = True strategy.sharding_configs = { "sharding_segment_strategy": "segment_broadcast_MB", - "segment_broadcast_MB": 0.2 + "segment_broadcast_MB": 0.2, + "sharding_degree": 2, } optimizer = paddle.fluid.optimizer.Momentum( diff --git a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py index fb2aaef7b3a668..549975f5d3f0f4 100755 --- a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py +++ b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py @@ -148,7 +148,8 @@ def set_strategy(self, strategy, name): strategy.sharding = True strategy.sharding_configs = { "sharding_segment_strategy": "segment_broadcast_MB", - "segment_broadcast_MB": 0.2 + "segment_broadcast_MB": 0.2, + "sharding_degree": 2, } elif name == "recompute-offload": strategy.recompute = True diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index fc5de320db2872..f8815259c09eb5 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -279,8 +279,8 @@ def test_sharding_clone_for_test(self): self.optimizer(avg_cost, strategy, train_prog, startup_prog) sharding.utils.comm_analyse(train_prog) test_prog = train_prog.clone(for_test=True) - # assume sharding_ring_id = 0 - sharding.utils.add_sync_comm(test_prog, 0) + # assume sharding_ring_id = 1 + sharding.utils.add_sync_comm(test_prog, 1) ops = [op.type for op in test_prog.global_block().ops] self.assertEqual(ops, [ From ceae74b41a63e66834a6a79df0d96491930dd31d Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 30 Mar 2021 17:22:54 +0800 Subject: [PATCH 19/24] sharding: update unitest --- .../meta_optimizers/sharding_optimizer.py | 3 - .../test_fleet_sharding_meta_optimizer.py | 195 ++++++++++++++++++ 2 files changed, 195 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index f7691f15e5e548..4dc5ed24bd31d0 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -400,9 +400,6 @@ def _prune_main_program(self, block): weightdecay_helper.prune_weight_decay(block, self._shard) # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism # group. 
and each Data Parallelism group should have its own sync of FoundInfinite - Model_Paramllelism_ring_id = self.sharding_ring_id - if self.mp_degree > 1: - Model_Paramllelism_ring_id = self.mp_ring_id # amp could use global group for sync FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, self.global_ring_id) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index f8815259c09eb5..4d6744f2b6fe48 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -292,5 +292,200 @@ def test_sharding_clone_for_test(self): ]) +class TestFleetMetaOptimizer(TestFleetMetaOptimizer): + def setUp(self): + os.environ["PADDLE_TRAINER_ID"] = "3" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002,127.0.0.1:36003,127.0.0.1:36004" + + def test_sharding_with_mp(self): + # NOTE(JZ-LIANG) MP parallelism need user to build model with MP API + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, _ = self.net(train_prog, startup_prog) + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.sharding = True + strategy.sharding_configs = { + "sharding_segment_strategy": "segment_broadcast_MB", + "segment_broadcast_MB": 0.2, + "segment_anchors": None, + "sharding_degree": 2, + "hybrid_dp": False, + "gradient_merge_acc_step": 1, + "mp_degree": 2 + } + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # should has ring id for MP + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(0, created_ring_ids) + + # check correctness of MP group + sharding_group_waiting_port = None + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "nccl_id_1": + sharding_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of sharding group + sharding_group_waiting_port = None + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "nccl_id_2": + dp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + + def test_sharding_hybrid_dp(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, _ = self.net(train_prog, startup_prog) + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.sharding = True + strategy.sharding_configs = { + "sharding_segment_strategy": "segment_broadcast_MB", + "segment_broadcast_MB": 0.2, + "segment_anchors": None, + "sharding_degree": 2, + "hybrid_dp": True, + "gradient_merge_acc_step": 1, + "mp_degree": 1 + } + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check ring id for outter dp + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(2, created_ring_ids) + + # check correctness of sharding group + sharding_group_waiting_port = None + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and 
op.desc.output_arg_names()[ + 0] == "nccl_id_1": + sharding_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of dp group + sharding_group_waiting_port = None + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "nccl_id_2": + dp_group_waiting_ports = op.desc.attr("other_endpoints") + self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + + # check loss scale for sharding hybrid dp + scale_ = -1 + for op in main_prog_ops: + if op.type == "scale": + scale_ = float(op.desc.attr("scale")) + self.assertEqual(scale_, 0.25) + + # check program (allreudce) + ops = [op.type for op in main_prog_ops] + self.assertEqual(ops, [ + 'fill_constant', 'fill_constant', 'fill_constant', + 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', + 'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', + 'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean', + 'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2', + 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', + 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', + 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_sync_comm_stream', 'momentum', 'momentum', 'momentum' + ]) + + def test_sharding_hybrid_dp_gm(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, _ = self.net(train_prog, startup_prog) + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.sharding = True + strategy.sharding_configs = { + "sharding_segment_strategy": "segment_broadcast_MB", + "segment_broadcast_MB": 0.2, + "segment_anchors": None, + "sharding_degree": 2, + "hybrid_dp": True, + "gradient_merge_acc_step": 4, + "mp_degree": 1 + } + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check ring id for outter dp + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(2, created_ring_ids) + + # check correctness of sharding group + sharding_group_waiting_port = None + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "nccl_id_1": + sharding_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of dp group + sharding_group_waiting_port = None + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "nccl_id_2": + dp_group_waiting_ports = op.desc.attr("other_endpoints") + self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + + # check program + fw_bw_ops = [op.type for op in train_prog.blocks[0].ops] + opt_ops = [op.type for op in train_prog.blocks[2].ops] + self.assertEqual(fw_bw_ops, [ + 'fill_constant', 'fill_constant', 'fill_constant', + 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', + 'c_sync_comm_stream', 'mul', 'elementwise_add', 'tanh', 'mul', + 'elementwise_add', 'tanh', 'mul', 
'elementwise_add', 'softmax',
+            'cross_entropy2', 'mean', 'fill_constant', 'scale', 'mean_grad',
+            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
+            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
+            'tanh_grad', 'elementwise_add_grad', 'mul_grad',
+            'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
+            'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
+            'c_sync_comm_stream', 'elementwise_add', 'elementwise_add',
+            'elementwise_add', 'increment', 'elementwise_mod', 'equal',
+            'conditional_block'
+        ])
+        self.assertEqual(opt_ops, [
+            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',
+            'scale', 'scale', 'momentum', 'momentum', 'momentum',
+            'fill_constant', 'fill_constant', 'fill_constant'
+        ])
+
+        # check loss scale for gradient merge
+        scale_ = -1
+        for op in train_prog.blocks[2].ops:
+            if op.type == "scale":
+                scale_ = float(op.desc.attr("scale"))
+        self.assertEqual(scale_, 0.25)
+
 
 if __name__ == "__main__":
     unittest.main()

From dde7d243300cb83d5818004c44c0b631fea684f1 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Wed, 31 Mar 2021 14:47:01 +0800
Subject: [PATCH 20/24] sharding: add more comments

---
 .../fleet/meta_optimizers/amp_optimizer.py          |  3 +--
 .../fleet/meta_optimizers/sharding/fp16_helper.py   |  6 ------
 .../fleet/meta_optimizers/sharding/utils.py         |  7 ++++---
 python/paddle/fluid/backward.py                     | 13 ++++++-------
 4 files changed, 11 insertions(+), 18 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
index 8e4ddedadf0682..02505e01197dc6 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
@@ -58,9 +58,8 @@ def _init_wrapped_opt(self):
         # computation by split the check_finite_and_unscale op.
         is_distributed = self.role_maker._worker_num() > 1
         if self.user_defined_strategy.sharding:
-            # if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel:
             # FIXME(wangxi). sharding failed when split check_finite_and_unscale
-            # FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior
+            # FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior and disable is_distributed.
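# Why the flag below is forced off: with sharding each rank holds only a
# slice of the gradients, and prune_fp16 (sharding/fp16_helper.py in this
# patch) already allreduce-maxes the FoundInfinite flag over the
# communicator ring, so additionally splitting check_finite_and_unscale
# across ranks would be redundant. A minimal sketch of the effect, using
# the names from this file:
#     self.wrapped_opt._set_distributed(False)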
is_distributed = False self.wrapped_opt._set_distributed(is_distributed) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py index e946ed5fb3fbe6..cf399f66946bd7 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -148,8 +148,6 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): OP_ROLE_KEY: OpRole.Optimize }) # this allreduce communication should not overlap with calc - # insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, - # [inf_var_int32]) block._insert_op_without_sync( update_loss_scaling_op_idx + 1, type='c_allreduce_max', @@ -160,10 +158,6 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): 'use_calc_stream': True, OP_ROLE_KEY: OpRole.Optimize }) - - # comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, - # ring_id, [inf_var_int32]) - block._insert_op_without_sync( update_loss_scaling_op_idx + 2, type='cast', diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index bf1e3186c44164..8b111026bdb916 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -103,6 +103,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): idx_gradient_clip_allreduce = -1 for idx, op in enumerate(block.ops): + # sharding use both allreduce and reduce to sync grad if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum": if op.all_attrs()["use_calc_stream"] == False: ring_id = op.desc.attr("ring_id") @@ -136,7 +137,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): if var_name in dp_grads_status and dp_grads_status[ var_name] == 0: dp_grads_status[var_name] = 1 - + # check sharding allreduce and reduce but skip megatron allreduce elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum": if op.all_attrs()["use_calc_stream"] == False: var_name = op.desc.input_arg_names()[0] @@ -503,7 +504,7 @@ def save_persistables(exe, dirname, main_program, filename=None): """ def is_opt_vars(var): - # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer + # NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer # now only Momentum and adam are compatible with sharding checks = [ "_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0", @@ -515,7 +516,7 @@ def is_opt_vars(var): return False def is_gradient_merge_vars(var): - # NOTE(liangjianzhong): to revise save/load logic in framework instead of write this naive rule + # NOTE(JZ-LIANG): to revise save/load logic in framework instead of write this naive rule return var.name.endswith("@GradiantMerge") diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index bd20b6d31a4f70..8fc972a6fd7954 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -890,15 +890,14 @@ def _append_backward_ops_with_checkpoints_( var_name_dict[name] = name + var_suffix # we should create the rename var in subprog, otherwise its VarType will be BOOL + ref_var = block.program.global_block().var(name) block.create_var( name=var_name_dict[name], - shape=block.program.global_block().var(name).shape, - dtype=block.program.global_block().var(name).dtype, - 
-                    type=block.program.global_block().var(name).type,
-                    persistable=block.program.global_block().var(
-                        name).persistable,
-                    stop_gradient=block.program.global_block().var(name)
-                    .stop_gradient)
+                    shape=ref_var.shape,
+                    dtype=ref_var.dtype,
+                    type=ref_var.type,
+                    persistable=ref_var.var(name).persistable,
+                    stop_gradient=ref_var.var(name).stop_gradient)
 
         # 3.a. add ops in current recompute_segment as forward recomputation ops
         buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,

From 60be6ece30a993791cf0b3feec60dff371e48c58 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Wed, 31 Mar 2021 15:42:19 +0800
Subject: [PATCH 21/24] recompute: fixed bug in create vars

---
 .../distributed/fleet/meta_optimizers/sharding_optimizer.py | 1 +
 python/paddle/fluid/backward.py                             | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 4dc5ed24bd31d0..77349355907f29 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -180,6 +180,7 @@ def minimize_impl(self,
             self._sharding_gradient_merge(main_block)
 
         # # check op dependency
+        # FIXME (JZ-LIANG) enable checking in the future.
         # check_broadcast(main_block)
         # check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
         #                     self.dp_ring_id)

diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 8fc972a6fd7954..b3a1834d49d3b3 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -896,8 +896,8 @@ def _append_backward_ops_with_checkpoints_(
                     shape=ref_var.shape,
                     dtype=ref_var.dtype,
                     type=ref_var.type,
-                    persistable=ref_var.var(name).persistable,
-                    stop_gradient=ref_var.var(name).stop_gradient)
+                    persistable=ref_var.persistable,
+                    stop_gradient=ref_var.stop_gradient)
 
         # 3.a. add ops in current recompute_segment as forward recomputation ops
         buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,

From b414ddf49d3d75c905052b3e881a5b8833c7f54e Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Fri, 2 Apr 2021 00:11:20 +0800
Subject: [PATCH 22/24] sharding temp to check ci bug

---
 .../paddle/fluid/tests/unittests/test_dist_base.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)
 mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_dist_base.py

diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
old mode 100644
new mode 100755
index d73698e7e024a8..4b8f3947db0b75
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -1107,6 +1107,18 @@ def _run_cluster_nccl2(self, model, envs, update_method, check_error_log,
         if check_error_log:
             print("outs[0]:", outs[0])
             print("outs[1]:", outs[1])
+
+        def print_file(outfilename):
+            print("######" * 5)
+            with open(outfilename, 'r') as outfile:
+                lines = outfile.readlines()
+                for line in lines:
+                    print(line)
+            print("######" * 5)
+
+        print_file("./_tr0_err.log")
+        print_file("./_tr1_err.log")
+
         return pickle.loads(outs[0]), pickle.loads(outs[1])
 
     def _run_pipeline(self, model, envs, check_error_log, log_name):

From 620f13812df5e524c599795cc2f77ea1c43961d9 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Fri, 2 Apr 2021 01:11:24 +0800
Subject: [PATCH 23/24] sharding: revise comm _wait func

---
 .../fleet/meta_optimizers/sharding_optimizer.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 77349355907f29..2badbf08099345 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -244,10 +244,9 @@ def _build_shard(self, params_grads):
             self._main_program.global_block())
 
     def _wait(self, ):
-        # only the first parallelsm group that init nccl need to be wait.
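
# ---------------------------------------------------------------------------
# Why the old rule being removed here is wrong once several parallel groups
# exist: every sharding group has a local rank 0, but only one process has
# global rank 0, so waiting at sharding_rank == 0 fires once per group
# instead of once per job. A small plain-Python illustration (the degrees and
# rank layout below are made up, not values Paddle assigns):
sharding_degree, dp_degree = 2, 2
world_size = sharding_degree * dp_degree
for global_rank in range(world_size):
    sharding_rank = global_rank % sharding_degree  # assumes sharding innermost
    print(global_rank,
          "waits(old rule):", sharding_rank == 0,
          "waits(new rule):", global_rank == 0)
# old rule -> 2 of the 4 ranks wait; new rule -> exactly global rank 0 waits.
# ---------------------------------------------------------------------------
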
- endpoints = self.sharding_group_endpoints[:] - current_endpoint = endpoints[self.sharding_rank] - if self.sharding_rank == 0: + endpoints = self.global_endpoints[:] + current_endpoint = endpoints[self.global_rank] + if self.global_rank == 0: self._collective_helper._wait(current_endpoint, endpoints) def collect_segment(self, segment, op_idx, block): From 073244275963f69587ab8c47f2f16ad500dee809 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 2 Apr 2021 02:23:57 +0800 Subject: [PATCH 24/24] sharding: revise comm init --- .../meta_optimizers/sharding_optimizer.py | 41 +++++++++++++++---- .../fluid/tests/unittests/test_dist_base.py | 11 ----- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 2badbf08099345..cf3f75740ee3dd 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -201,33 +201,56 @@ def _init_comm(self): # global self._collective_helper._init_communicator( - self._startup_program, self.current_endpoint, self.global_endpoints, - self.global_rank, self.global_ring_id, False) + self._startup_program, + self.current_endpoint, + self.global_endpoints, + self.global_rank, + self.global_ring_id, + False, + global_ring_id=self.global_ring_id, + sync=False) append_naive_sync(startup_block, self.startup_prog_sync_var, self.global_ring_id) # mp if self.mp_degree > 1: self._collective_helper._init_communicator( - self._startup_program, self.current_endpoint, - self.mp_group_endpoints, self.mp_rank, self.mp_ring_id, False) + self._startup_program, + self.current_endpoint, + self.mp_group_endpoints, + self.mp_rank, + self.mp_ring_id, + False, + global_ring_id=self.global_ring_id, + sync=False) append_naive_sync(startup_block, self.startup_prog_sync_var, self.global_ring_id) # sharding if self.sharding_degree > 1: self._collective_helper._init_communicator( - self._startup_program, self.current_endpoint, - self.sharding_group_endpoints, self.sharding_rank, - self.sharding_ring_id, False) + self._startup_program, + self.current_endpoint, + self.sharding_group_endpoints, + self.sharding_rank, + self.sharding_ring_id, + False, + global_ring_id=self.global_ring_id, + sync=False) append_naive_sync(startup_block, self.startup_prog_sync_var, self.global_ring_id) # dp if self.dp_degree > 1: self._collective_helper._init_communicator( - self._startup_program, self.current_endpoint, - self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, False) + self._startup_program, + self.current_endpoint, + self.dp_group_endpoints, + self.dp_rank, + self.dp_ring_id, + False, + global_ring_id=self.global_ring_id, + sync=False) append_naive_sync(startup_block, self.startup_prog_sync_var, self.global_ring_id) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index a7fc813eeae5f9..37494294418f1c 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -1126,17 +1126,6 @@ def _run_cluster_nccl2(self, model, envs, update_method, check_error_log, print("outs[0]:", outs[0]) print("outs[1]:", outs[1]) - def print_file(outfilename): - print("######" * 5) - with open(outfilename, 'r') as outfile: - lines = outfile.readlines() - for line in lines: - print(line) - print("######" * 5) - - print_file("./_tr0_err.log") 
- print_file("./_tr1_err.log") - return pickle.loads(outs[0]), pickle.loads(outs[1]) def _run_pipeline(self, model, envs, check_error_log, log_name):
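
The last hunk above initializes one communicator per parallelism level (global, mp, sharding, dp), each init followed by a naive sync on the global ring. As a plain-Python sketch of the rank layout such a hierarchy implies — the degrees and the innermost-to-outermost order (mp, then sharding, then dp) below are illustrative assumptions, not values taken from the patch:

    mp_degree, sharding_degree, dp_degree = 2, 2, 2
    world_size = mp_degree * sharding_degree * dp_degree  # 8 hypothetical ranks
    for rank in range(world_size):
        mp_rank = rank % mp_degree
        sharding_rank = (rank // mp_degree) % sharding_degree
        dp_rank = rank // (mp_degree * sharding_degree)
        print("global rank", rank,
              "-> mp", mp_rank, "/ sharding", sharding_rank, "/ dp", dp_rank)

Each printed rank participates in one ring per level, which is why the patch passes a distinct ring_id per call while threading the same global_ring_id through for the inter-init synchronization.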