@@ -287,6 +287,18 @@ def filter_parameters(self, parameter_list, hcg):
]
return parameter_list

def _get_param_grad(self, param):
if not param.trainable:
return None

if hasattr(param, "main_grad"):
assert (
param._grad_ivar() is None
), "param.grad should be None when using main_grad"
return param.main_grad

return param._grad_ivar()

def reduce_gradients(self, parameter_list, hcg):
# TODO merge grad / nrank with dp
logger.debug("sharding start gradients sync")
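For reference, a minimal standalone sketch of the selection rule the new `_get_param_grad` helper centralizes: return nothing for non-trainable parameters, prefer `main_grad` when mixed-precision main-grad accumulation has attached it to the parameter, and otherwise fall back to the regular grad. The `FakeParam` stand-in below is purely illustrative and is not Paddle's Tensor API.

```python
# Standalone sketch of the main_grad-vs-grad selection rule, using a dummy
# parameter object instead of a real Paddle Tensor (illustrative only).
class FakeParam:
    def __init__(self, trainable, grad=None, main_grad=None):
        self.trainable = trainable
        self._grad = grad
        if main_grad is not None:
            self.main_grad = main_grad  # present only when main_grad is enabled

    def _grad_ivar(self):
        return self._grad


def get_param_grad(param):
    """Return the gradient to reduce, preferring main_grad when it exists."""
    if not param.trainable:
        return None  # frozen parameters contribute no gradient
    if hasattr(param, "main_grad"):
        # With fp32 main_grad accumulation, the regular grad slot stays empty.
        assert param._grad_ivar() is None, "param.grad should be None when using main_grad"
        return param.main_grad
    return param._grad_ivar()


if __name__ == "__main__":
    assert get_param_grad(FakeParam(trainable=False)) is None
    assert get_param_grad(FakeParam(trainable=True, grad="g")) == "g"
    assert get_param_grad(FakeParam(trainable=True, main_grad="mg")) == "mg"
```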
@@ -296,14 +308,7 @@ def reduce_gradients(self, parameter_list, hcg):
return
with framework.no_grad():
for param in parameter_list:
g_var = None
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
if param.trainable and hasattr(param, "main_grad"):
assert (
param._grad_ivar() is None
), "param.grad should be None when using main_grad"
g_var = param.main_grad
g_var = self._get_param_grad(param)
if g_var is not None:
reduce_op = (
ReduceOp.AVG if self.use_reduce_avg else ReduceOp.SUM
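The reduce call itself is collapsed in this view, but the `reduce_op` chosen here feeds a per-parameter reduce onto the rank that owns the shard. A hedged sketch of that pattern follows; the destination-rank and group arguments are illustrative assumptions, not the exact call from this file.

```python
# Sketch: reduce one parameter's gradient onto its owner rank in the
# sharding group. With use_reduce_avg=True the collective averages in one
# step; with SUM the 1/N scaling has to happen elsewhere in the pipeline.
import paddle.distributed as dist
from paddle.distributed import ReduceOp


def reduce_shard_grad(g_var, dst_global_rank, sharding_group, use_reduce_avg):
    reduce_op = ReduceOp.AVG if use_reduce_avg else ReduceOp.SUM
    dist.reduce(
        g_var,
        dst=dst_global_rank,  # global rank id of the shard's owner
        op=reduce_op,
        group=sharding_group,
        sync_op=True,
    )
```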
@@ -330,29 +335,39 @@ def reduce_gradients(self, parameter_list, hcg):

def _sharding_sync_parameters(self):
"""
sync parameter across sharding group
Synchronize parameters across the sharding group efficiently.
"""
# TODO speed up this functional

with framework.no_grad():
# TODO detach not need (?)
# Choose the parameter collection to sync based on whether tensor fusion is enabled.
valid_rank_to_params = (
self._rank2params
if not self.tensor_fusion
else self._rank2fused
)

# Pre-compute sharding group ranks for efficiency
sharding_group_ranks = self._hcg.get_sharding_parallel_group().ranks

broadcast_tasks = []
for rank, params in valid_rank_to_params.items():
# Compute the global source rank only once for each rank's set of parameters
src_rank = sharding_group_ranks[rank]

for param in params:
task = paddle.distributed.broadcast(
param,
# the collective API need src rank to be the global rank id
# instead of the relative logic rank id within group
src=self._hcg.get_sharding_parallel_group().ranks[rank],
group=self._hcg.get_sharding_parallel_group(),
sync_op=False,
)
broadcast_tasks.append(task)
# NOTE: Skip parameters that yield no gradient here (e.g., parameters
# frozen during training); non-trainable parameters should not be
# broadcast.
g_var = self._get_param_grad(param)
if g_var is not None:
task = paddle.distributed.broadcast(
param,
src=src_rank,
group=self._hcg.get_sharding_parallel_group(),
sync_op=False,
)
broadcast_tasks.append(task)

# Wait for all async broadcast tasks to complete
for task in broadcast_tasks:
task.wait()
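A self-contained sketch of the broadcast-and-wait pattern above: translate the group-local rank index into a global rank via `group.ranks`, launch every broadcast with `sync_op=False`, then wait on the returned tasks. The `rank_to_params` mapping and the group handle are assumed inputs for illustration.

```python
# Sketch of the async parameter broadcast used by _sharding_sync_parameters
# (inputs are assumed: a {group-local rank: [params]} mapping and a process group).
import paddle
import paddle.distributed as dist


def sync_sharded_params(rank_to_params, sharding_group):
    group_ranks = sharding_group.ranks  # group-local index -> global rank id
    tasks = []
    with paddle.no_grad():
        for local_rank, params in rank_to_params.items():
            src_rank = group_ranks[local_rank]  # collective APIs take global ranks
            for param in params:
                task = dist.broadcast(
                    param,
                    src=src_rank,
                    group=sharding_group,
                    sync_op=False,  # queue the broadcast without blocking
                )
                tasks.append(task)
    # All broadcasts were launched asynchronously; block until each finishes.
    for task in tasks:
        task.wait()
```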
