
Commit 75cc705

dp and sharding coexist (#56096)
* dp and sharding coexist
* dp
1 parent 77da910 commit 75cc705

File tree: 2 files changed, +60 −22 lines


python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 9 additions & 0 deletions
@@ -100,6 +100,15 @@ def clear_grad(self, set_to_zero=True):
             elif not hasattr(p, "main_grad"):
                 p.clear_gradient(set_to_zero)
 
+    def filter_parameters(self, parameter_list, hcg):
+        sharding_parallel_rank = hcg.get_sharding_parallel_rank()
+        parameter_list = [
+            param
+            for param in parameter_list
+            if self._param2rank[param.name] == sharding_parallel_rank
+        ]
+        return parameter_list
+
     def _partition_parameters(self):
         """
         Partitions parameters among sharding ranks.
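The new filter_parameters helper keeps only the parameters whose owning sharding rank, as recorded in self._param2rank, matches the local sharding rank. Below is a minimal standalone sketch of that filtering idea; FakeParam, FakeHCG, and the hard-coded param2rank mapping are illustrative stand-ins, not Paddle APIs.

from dataclasses import dataclass

@dataclass
class FakeParam:          # hypothetical stand-in for a Paddle parameter
    name: str

class FakeHCG:            # hypothetical stand-in for the hybrid communicate group
    def __init__(self, sharding_rank):
        self._sharding_rank = sharding_rank

    def get_sharding_parallel_rank(self):
        return self._sharding_rank

def filter_parameters(parameter_list, hcg, param2rank):
    # keep only the parameters owned by the local sharding rank
    sharding_parallel_rank = hcg.get_sharding_parallel_rank()
    return [p for p in parameter_list if param2rank[p.name] == sharding_parallel_rank]

params = [FakeParam("w0"), FakeParam("w1"), FakeParam("w2"), FakeParam("w3")]
param2rank = {"w0": 0, "w1": 1, "w2": 0, "w3": 1}  # assignment a partitioner might produce
print([p.name for p in filter_parameters(params, FakeHCG(sharding_rank=0), param2rank)])
# ['w0', 'w2']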

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 51 additions & 22 deletions
@@ -293,19 +293,17 @@ def _dygraph_clip(self, params_grads):
             params_grads, global_norm_var_dist, global_norm_var_not_dist
         )
 
-    def _comm_and_clip(
-        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
-    ):
-        # sharding first
-        sharding_flag = (
-            self._hcg.get_sharding_parallel_world_size() > 1
-            and self._hcg.get_data_parallel_world_size() == 1
-        )
+    def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
+        sharding_flag = self._hcg.get_sharding_parallel_world_size() > 1
+        dp_flag = self._hcg.get_data_parallel_world_size() > 1
         mp_flag = self._hcg.get_model_parallel_world_size() > 1
-        # add all reduce to get global norm of distributed params_and_grads
+        pp_flag = self._hcg.get_pipe_parallel_world_size() > 1
+
+        # when g_shard_norm_align_dp is disabled, grads are sharded across the sharding group
         if sharding_flag and not g_shard_norm_align_dp:
             # norm of mp distributed variable
             if mp_flag:
+                # the dist norm is reduced across the sharding group here, and across the mp and pp groups later
                 paddle.distributed.all_reduce(
                     global_norm_var_dist,
                     group=self._hcg.get_sharding_parallel_group(),
@@ -315,21 +313,40 @@ def _comm_and_clip(
                 global_norm_var_not_dist,
                 group=self._hcg.get_sharding_parallel_group(),
             )
+
         # norm of mp distributed variable
         if mp_flag:
-            # dist should reduce among sharding group、mp group、pp group
-            paddle.distributed.all_reduce(
-                global_norm_var_dist,
-                group=self._hcg.get_check_parallel_group(sharding_flag),
-            )
+            # the else branch alone would suffice, but this branch is kept for numerical backward compatibility
+            if not (dp_flag and sharding_flag):
+                paddle.distributed.all_reduce(
+                    global_norm_var_dist,
+                    group=self._hcg.get_check_parallel_group(sharding_flag),
+                )
+            else:
+                # global_norm_var_dist is all-reduced across the mp group and then the pp group
+                paddle.distributed.all_reduce(
+                    global_norm_var_dist,
+                    group=self._hcg.get_model_parallel_group(),
+                )
+                if pp_flag:
+                    paddle.distributed.all_reduce(
+                        global_norm_var_dist,
+                        group=self._hcg.get_pipe_parallel_group(),
+                    )
 
         # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
-        if self._hcg.get_pipe_parallel_world_size() > 1:
+        if pp_flag:
             paddle.distributed.all_reduce(
                 global_norm_var_not_dist,
                 group=self._hcg.get_pipe_parallel_group(),
             )
 
+    def _comm_and_clip(
+        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
+    ):
+
+        self._global_norm(global_norm_var_dist, global_norm_var_not_dist)
+
         global_norm_var_fp32 = paddle.sqrt(
             global_norm_var_dist + global_norm_var_not_dist
         )
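The refactored _global_norm picks its all-reduce groups from four world-size flags. The sketch below is a rough, communication-free decision table of that branching; the group labels are plain strings for illustration rather than Paddle group objects, and g_shard_norm_align_dp defaults to off as in the new code path.

def global_norm_reduce_plan(sharding, dp, mp, pp, g_shard_norm_align_dp=False):
    """Return the group labels the dist / not-dist norms would be reduced over."""
    dist_groups, not_dist_groups = [], []

    # grads sharded across the sharding group: sum the partial norms there first
    if sharding and not g_shard_norm_align_dp:
        if mp:
            dist_groups.append("sharding")
        not_dist_groups.append("sharding")

    if mp:
        if not (dp and sharding):
            # legacy path: one all-reduce over the combined "check" group
            dist_groups.append("check(mp/pp[/sharding])")
        else:
            # dp and sharding coexist: reduce over mp, then pp, separately
            dist_groups.append("mp")
            if pp:
                dist_groups.append("pp")

    if pp:
        not_dist_groups.append("pp")
    return dist_groups, not_dist_groups

print(global_norm_reduce_plan(sharding=True, dp=True, mp=True, pp=True))
# (['sharding', 'mp', 'pp'], ['sharding', 'pp'])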
@@ -554,15 +571,21 @@ def _step(self, parameters_list):
     @no_grad()
     @framework.dygraph_only
     def step(self):
-        parameters_list = obtain_optimizer_parameters_list(self._inner_opt)
+        parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
+        dp_parameter_list = parameter_list
         if self._sharding_enable:
             assert isinstance(self._inner_opt, DygraphShardingOptimizer)
-            self._inner_opt.reduce_gradients(list(parameters_list), self._hcg)
+            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
+            # the later dp sync does not need the global parameter list
+            if not g_shard_norm_align_dp:
+                dp_parameter_list = self._inner_opt.filter_parameters(
+                    parameter_list, self._hcg
+                )
 
         if self._dp_enable:
-            fused_allreduce_gradients(list(parameters_list), self._hcg)
+            fused_allreduce_gradients(dp_parameter_list, self._hcg)
 
-        self._step(parameters_list)
+        self._step(parameter_list)
 
     @no_grad()
     def minimize(
@@ -574,14 +597,20 @@ def minimize(
         parameter_list = (
             parameters if parameters else self._inner_opt._parameter_list
         )
-
+        parameter_list = list(parameter_list)
+        dp_parameter_list = parameter_list
         # Here sharding should use global parameter list
         if self._sharding_enable:
             assert isinstance(self._inner_opt, DygraphShardingOptimizer)
-            self._inner_opt.reduce_gradients(list(parameter_list), self._hcg)
+            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
+            # the later dp sync does not need the global parameter list
+            if not g_shard_norm_align_dp:
+                dp_parameter_list = self._inner_opt.filter_parameters(
+                    parameter_list, self._hcg
+                )
 
         if self._dp_enable:
-            fused_allreduce_gradients(list(parameter_list), self._hcg)
+            fused_allreduce_gradients(dp_parameter_list, self._hcg)
 
         return self._inner_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set
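In step() and minimize(), the dp all-reduce can now run over only the parameters owned by the local sharding rank, while the sharding reduction and the optimizer update still see the full list. Below is a hypothetical trace of that call order with mock objects; MockShardingOpt and the integer hcg rank are illustrative stand-ins, not Paddle's real types.

class MockShardingOpt:                      # illustrative stand-in for the sharding optimizer
    def reduce_gradients(self, params, hcg):
        print("sharding reduce over the full list:", params)

    def filter_parameters(self, params, hcg):
        # pretend the local sharding rank (hcg, an int here) owns every other parameter
        return [p for i, p in enumerate(params) if i % 2 == hcg]

def fused_allreduce_gradients(params, hcg):  # illustrative stand-in for the dp sync
    print("dp all-reduce over:", params)

def hybrid_step(parameter_list, inner_opt, hcg, sharding_enable=True,
                dp_enable=True, g_shard_norm_align_dp=False):
    parameter_list = list(parameter_list)
    dp_parameter_list = parameter_list               # default: dp syncs every parameter
    if sharding_enable:
        inner_opt.reduce_gradients(parameter_list, hcg)   # sharding sync needs the full list
        if not g_shard_norm_align_dp:
            # the later dp sync only touches the parameters this sharding rank owns
            dp_parameter_list = inner_opt.filter_parameters(parameter_list, hcg)
    if dp_enable:
        fused_allreduce_gradients(dp_parameter_list, hcg)
    return parameter_list                            # the optimizer update still sees the full list

hybrid_step(["w0", "w1", "w2", "w3"], MockShardingOpt(), hcg=0)
# sharding reduce over the full list: ['w0', 'w1', 'w2', 'w3']
# dp all-reduce over: ['w0', 'w2']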
