From d20c6468be308892af5d169900763707a0a13425 Mon Sep 17 00:00:00 2001
From: ForFishes <2282912238@qq.com>
Date: Mon, 11 Mar 2024 15:04:12 +0800
Subject: [PATCH 1/2] support trainable param for sharding stagev1

---
 .../dygraph_sharding_optimizer.py             | 57 ++++++++++++-------
 1 file changed, 36 insertions(+), 21 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index 2b0001ddc5c8a9..4ea310a011a34a 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -287,6 +287,18 @@ def filter_parameters(self, parameter_list, hcg):
         ]
         return parameter_list
 
+    def _get_param_grad(self, param):
+        if not param.trainable:
+            return None
+
+        if hasattr(param, "main_grad"):
+            assert (
+                param._grad_ivar() is None
+            ), "param.grad should be None when using main_grad"
+            return param.main_grad
+
+        return param._grad_ivar() if param._grad_ivar() is not None else None
+
     def reduce_gradients(self, parameter_list, hcg):
         # TODO merge grad / nrank with dp
         logger.debug("sharding start gradients sync")
@@ -296,14 +308,7 @@ def reduce_gradients(self, parameter_list, hcg):
             return
         with framework.no_grad():
             for param in parameter_list:
-                g_var = None
-                if param.trainable and (param._grad_ivar() is not None):
-                    g_var = param._grad_ivar()
-                if param.trainable and hasattr(param, "main_grad"):
-                    assert (
-                        param._grad_ivar() is None
-                    ), "param.grad should be None when using main_grad"
-                    g_var = param.main_grad
+                g_var = self._get_param_grad(param)
                 if g_var is not None:
                     reduce_op = (
                         ReduceOp.AVG if self.use_reduce_avg else ReduceOp.SUM
@@ -330,29 +335,39 @@ def reduce_gradients(self, parameter_list, hcg):
 
     def _sharding_sync_parameters(self):
         """
-        sync parameter across sharding group
+        Synchronize parameter across sharding group efficiently.
         """
-        # TODO speed up this functional
-
         with framework.no_grad():
-            # TODO detach not need (?)
+            # Choose appropriate parameters collection based on whether tensor fusion is enabled.
             valid_rank_to_params = (
                 self._rank2params
                 if not self.tensor_fusion
                 else self._rank2fused
             )
+
+            # Pre-compute sharding group ranks for efficiency
+            sharding_group_ranks = self._hcg.get_sharding_parallel_group().ranks
+
             broadcast_tasks = []
            for rank, params in valid_rank_to_params.items():
+                # Compute the global source rank only once per each rank's set of parameters
+                src_rank = sharding_group_ranks[rank]
+
                 for param in params:
-                    task = paddle.distributed.broadcast(
-                        param,
-                        # the collective API need src rank to be the global rank id
-                        # instead of the relative logic rank id within group
-                        src=self._hcg.get_sharding_parallel_group().ranks[rank],
-                        group=self._hcg.get_sharding_parallel_group(),
-                        sync_op=False,
-                    )
-                    broadcast_tasks.append(task)
+                    # NOTE: We should check if the parameter is trainable, because some parameters
+                    # (e.g., freeze the parameters for training) are not trainable and should
+                    # not be broadcasted.
+                    g_var = self._get_param_grad(param)
+                    if g_var is not None:
+                        task = paddle.distributed.broadcast(
+                            param,
+                            src=src_rank,
+                            group=self._hcg.get_sharding_parallel_group(),
+                            sync_op=False,
+                        )
+                        broadcast_tasks.append(task)
+
+            # Wait for all async broadcast tasks to complete
             for task in broadcast_tasks:
                 task.wait()
 

From 9ffea9ef920eeb7b0d02c967fa09bb7d2346d818 Mon Sep 17 00:00:00 2001
From: ForFishes <2282912238@qq.com>
Date: Wed, 13 Mar 2024 15:39:55 +0800
Subject: [PATCH 2/2] fix code

---
 .../dygraph_optimizer/dygraph_sharding_optimizer.py           | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index 4ea310a011a34a..0ec8fa95992eae 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -297,7 +297,7 @@ def _get_param_grad(self, param):
             ), "param.grad should be None when using main_grad"
             return param.main_grad
 
-        return param._grad_ivar() if param._grad_ivar() is not None else None
+        return param._grad_ivar()
 
     def reduce_gradients(self, parameter_list, hcg):
         # TODO merge grad / nrank with dp
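Note (editor's illustration, not part of the patch series): after both commits, _get_param_grad centralizes the trainable/main_grad/grad selection that reduce_gradients and _sharding_sync_parameters now share. The snippet below is a standalone sketch of that selection rule; _FakeParam and get_param_grad are hypothetical stand-in names for illustration only and are not Paddle APIs.

    # Sketch of the gradient-selection rule introduced by this PR.
    # "_FakeParam" is a hypothetical stand-in for a Paddle parameter; it only
    # mimics the attributes the helper touches (trainable, main_grad, _grad_ivar).
    class _FakeParam:
        def __init__(self, trainable=True, grad=None, main_grad=None):
            self.trainable = trainable
            self._grad = grad
            if main_grad is not None:
                self.main_grad = main_grad

        def _grad_ivar(self):
            return self._grad


    def get_param_grad(param):
        # Frozen parameters contribute no gradient, so both the reduce path and
        # the broadcast path can skip them.
        if not param.trainable:
            return None
        # When main_grad (the master gradient used in mixed-precision training)
        # is present, param.grad must be unset; main_grad is the tensor to use.
        if hasattr(param, "main_grad"):
            assert param._grad_ivar() is None, (
                "param.grad should be None when using main_grad"
            )
            return param.main_grad
        # Otherwise fall back to the regular gradient (which may be None).
        return param._grad_ivar()


    if __name__ == "__main__":
        assert get_param_grad(_FakeParam(trainable=False, grad=1.0)) is None
        assert get_param_grad(_FakeParam(grad=2.0)) == 2.0
        assert get_param_grad(_FakeParam(grad=None, main_grad=3.0)) == 3.0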