
Commit 4f93c1f

ForFishes authored and wentaoyu committed
part-4 cherry-pick from: [Distributed] Support param_group in sharding-stage1 (PaddlePaddle#56626)
* support param group in sharding
* fix utest
1 parent 51a3dc2 commit 4f93c1f

File tree

4 files changed: +319 -128 lines changed


python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 111 additions & 55 deletions
@@ -14,6 +14,7 @@
 
 ######
 import os
+from collections import defaultdict
 from distutils.util import strtobool
 from functools import reduce
 
@@ -22,6 +23,7 @@
 from paddle.base.dygraph import base as imperative_base
 from paddle.base.framework import EagerParamBase
 from paddle.distributed import fleet
+from paddle.nn import ClipGradByGlobalNorm
 
 from ...utils.log_util import logger
 from ...utils.tensor_fusion_helper import (
@@ -62,21 +64,27 @@ class DygraphShardingOptimizer:
     # 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm
 
     def __init__(self, optimizer, hcg):
-        logger.info("init DygraphShardingOptimizer")
-        # TODO(pangengzheng): support param_groups
-        if isinstance(optimizer._parameter_list[0], dict):
-            raise TypeError(
-                "Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter"
-            )
         if not hasattr(optimizer, '_apply_optimize') or not callable(
             optimizer._apply_optimize
         ):
             raise ValueError(
                 "the optimzier object should have _apply_optimize function"
             )
-        # the self._parameter_list holds the whole model paramters
-        self._parameter_list = optimizer._parameter_list
-        self._origin_parameter_list = self._parameter_list
+
+        self._using_param_groups = isinstance(
+            optimizer._parameter_list[0], dict
+        )
+
+        self._parameter_list = []
+        self._param_2_group_id = {}
+        if self._using_param_groups:
+            for idx, param_group in enumerate(optimizer._param_groups):
+                for param in param_group['params']:
+                    self._param_2_group_id[id(param)] = idx
+                    self._parameter_list.append(param)
+        else:
+            self._parameter_list = optimizer._parameter_list
+
         self._inner_opt = optimizer
         self._hcg = hcg
         self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
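
Note: the constructor now detects param_groups simply by checking whether the first entry of `optimizer._parameter_list` is a dict, then flattens every group into one parameter list while remembering each parameter's group index, keyed by `id(param)`. Below is a minimal, framework-free sketch of that bookkeeping; the `FakeParam` class is a hypothetical stand-in for a Paddle parameter, used only for illustration.

```python
# Hypothetical stand-in for a Paddle parameter; illustration only.
class FakeParam:
    def __init__(self, name):
        self.name = name


w0, w1, b0 = FakeParam("w_0"), FakeParam("w_1"), FakeParam("b_0")
param_groups = [
    {"params": [w0, w1], "weight_decay": 0.01},
    {"params": [b0], "weight_decay": 0.0},
]

# Same flattening as in __init__ above: one flat list of parameters plus a map
# from id(param) to the index of the group the parameter came from.
param_2_group_id = {}
parameter_list = []
for idx, group in enumerate(param_groups):
    for param in group["params"]:
        param_2_group_id[id(param)] = idx
        parameter_list.append(param)

assert [p.name for p in parameter_list] == ["w_0", "w_1", "b_0"]
assert param_2_group_id[id(b0)] == 1
```

Partitioning can then work on the flat list exactly as before; the group index is only needed later, when the local per-rank groups are rebuilt.
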
@@ -110,49 +118,67 @@ def __init__(self, optimizer, hcg):
         self._rank2params = self._partition_parameters()
         self._param2rank = self._map_param_to_rank()
 
-        if not self.tensor_fusion and not self.comm_overlap:
-            local_params = self._rank2params[self._sharding_rank]
-            self._set_inner_opt_attr('_parameter_list', local_params)
-            self._set_inner_opt_attr('_param_groups', local_params)
-        else:
-            self._tensor_fusion()
-
-            decay_params = [
-                p.name for p in self._rank2decay[self._sharding_rank]
+        if self._using_param_groups:
+            param_groups = [
+                {"params": []} for _ in range(len(optimizer._param_groups))
             ]
-            local_fused_params = self._rank2fused[self._sharding_rank]
-            apply_decay_param_fun = lambda x: x in decay_params
-
-            all_fused_params = []
-            for v in self._rank2fused.values():
-                all_fused_params += v
-            self._parameter_list = all_fused_params
-            self._param_groups = all_fused_params
+            for idx, pg in enumerate(optimizer._param_groups):
+                param_groups[idx].update(
+                    {k: v for k, v in pg.items() if k != 'params'}
+                )
+            for param in self._rank2params[self._sharding_rank]:
+                group_id = self._param_2_group_id[id(param)]
+                param_groups[group_id]['params'].append(param)
 
-            self._set_inner_opt_attr('_parameter_list', local_fused_params)
-            self._set_inner_opt_attr('_param_groups', local_fused_params)
-            if self.comm_overlap:
-                # Only set local param for check finite when comm overlap.
-                # Under comm overlap, all grads will be communicated before check_finite.
-                # Therefore, each sharding rank can get all grads' info at check_finite.
-                # Without comm overlap, all grads will be communicated after check_finite,
-                # which means each sharding rank should do check_finite to all grads.
-                self._local_parameter_list = local_fused_params
-            origin_decay_param_fun = getattr(
-                self._inner_opt, '_apply_decay_param_fun', None
+            self._set_inner_opt_attr('_param_groups', param_groups)
+            self._set_inner_opt_attr(
+                '_parameter_list', self._rank2params[self._sharding_rank]
             )
-            if origin_decay_param_fun is not None:
-                self._set_inner_opt_attr(
-                    '_apply_decay_param_fun', apply_decay_param_fun
+            self._param_groups = self._parameter_list
+        else:
+            if not self.tensor_fusion and not self.comm_overlap:
+                local_params = self._rank2params[self._sharding_rank]
+                self._set_inner_opt_attr('_parameter_list', local_params)
+                self._set_inner_opt_attr('_param_groups', local_params)
+            else:
+                self._tensor_fusion()
+
+                decay_params = [
+                    p.name for p in self._rank2decay[self._sharding_rank]
+                ]
+                local_fused_params = self._rank2fused[self._sharding_rank]
+                apply_decay_param_fun = lambda x: x in decay_params
+
+                all_fused_params = []
+                for v in self._rank2fused.values():
+                    all_fused_params += v
+                self._parameter_list = all_fused_params
+                self._param_groups = all_fused_params
+
+                self._set_inner_opt_attr('_parameter_list', local_fused_params)
+                self._set_inner_opt_attr('_param_groups', local_fused_params)
+                if self.comm_overlap:
+                    # Only set local param for check finite when comm overlap.
+                    # Under comm overlap, all grads will be communicated before check_finite.
+                    # Therefore, each sharding rank can get all grads' info at check_finite.
+                    # Without comm overlap, all grads will be communicated after check_finite,
+                    # which means each sharding rank should do check_finite to all grads.
+                    self._local_parameter_list = local_fused_params
+                origin_decay_param_fun = getattr(
+                    self._inner_opt, '_apply_decay_param_fun', None
                 )
-            # Note: during the tensor fusion for parameters, the allocator will apply for
-            # some extra GPU memory for the fused big paramters. This extra GPU memory will
-            # be useless at once the fusion has done. But the Paddle's allocator won't
-            # release those memory, it will hold that part in the memory poll. So after
-            # tensor fusion, the 'reserved' memory will increase but the 'allocate' memory
-            # won't change. To avoid failure on some other applications (such as some nvtx
-            # operations), here we manulay let the allocator release the cached memory.
-            paddle.device.cuda.empty_cache()
+                if origin_decay_param_fun is not None:
+                    self._set_inner_opt_attr(
+                        '_apply_decay_param_fun', apply_decay_param_fun
+                    )
+                # Note: during the tensor fusion for parameters, the allocator will apply for
+                # some extra GPU memory for the fused big paramters. This extra GPU memory will
+                # be useless at once the fusion has done. But the Paddle's allocator won't
+                # release those memory, it will hold that part in the memory poll. So after
+                # tensor fusion, the 'reserved' memory will increase but the 'allocate' memory
+                # won't change. To avoid failure on some other applications (such as some nvtx
+                # operations), here we manulay let the allocator release the cached memory.
+                paddle.device.cuda.empty_cache()
 
     def clear_grad(self, set_to_zero=True):
         """
@@ -331,6 +357,9 @@ def minimize(
         # NOTE in dygraph mode, the only different between step and minimize is that minimize
         # allow user to customize the parameters for updating on each step
 
+        assert (
+            not self._using_param_groups
+        ), "minimize() is not support if using param_groups"
         input_param_names = {param.name for param in parameters}
         parameters = list(
             filter(
@@ -356,14 +385,12 @@ def step(self):
         # otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params
         # TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
         origin_clip = self._inner_opt._grad_clip
-        target_param_list = (
-            self._origin_parameter_list
-            if (not self.tensor_fusion or not self.fuse_optimizer)
-            else self._parameter_list
-        )
-        if not isinstance(target_param_list[0], dict):
+        if (
+            not isinstance(self._parameter_list[0], dict)
+            or not self._using_param_groups
+        ):
             params_grads = []
-            for param in target_param_list:
+            for param in self._parameter_list:
                 if (
                     hasattr(param, "regularizer")
                     and param.regularizer is not None
@@ -398,6 +425,35 @@ def step(self):
             if g_shard_norm_align_dp:
                 # restore the grad clip
                 self._set_inner_opt_attr('_grad_clip', origin_clip)
+        else:
+            # optimize parameters in groups
+            for param_group in self._inner_opt._param_groups:
+                params_grads = defaultdict(lambda: [])
+
+                # TODO(shenliang03): support ClipGradByGlobalNorm in sharding when using param_groups
+                grad_clip = param_group['grad_clip']
+                assert not isinstance(
+                    grad_clip, ClipGradByGlobalNorm
+                ), "ClipGradByGlobalNorm is not support if using param_groups in sharding"
+
+                for param in param_group['params']:
+                    if param.stop_gradient:
+                        continue
+
+                    grad_var = param._grad_ivar()
+                    if (
+                        hasattr(param, "main_grad")
+                        and param.main_grad is not None
+                    ):
+                        grad_var = param.main_grad
+
+                    params_grads['params'].append((param, grad_var))
+                params_grads.update(
+                    {k: v for k, v in param_group.items() if k != 'params'}
+                )
+                self._apply_optimize(
+                    loss=None, startup_program=None, params_grads=params_grads
+                )
 
         # sync parameters across sharding ranks
         self._sharding_sync_parameters()
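
Note: when param_groups are active, `step()` bypasses the existing fused/global-norm path and instead assembles one `params_grads` mapping per group: `(param, grad)` pairs are collected under 'params' in a defaultdict, the group's remaining keys are copied in, and the result is handed to `_apply_optimize`, so each group keeps its own hyperparameters. A framework-free sketch of that assembly is below; `apply_optimize` here is a hypothetical stand-in for the inner optimizer's `_apply_optimize`.

```python
from collections import defaultdict


def apply_optimize(params_grads):
    # Hypothetical stand-in: the real code calls self._apply_optimize(...).
    updated = [p for p, _ in params_grads["params"]]
    print(f"update {updated} with weight_decay={params_grads['weight_decay']}")


param_groups = [
    {"params": ["w_0", "w_1"], "weight_decay": 0.01},
    {"params": ["b_0"], "weight_decay": 0.0},
]
grads = {"w_0": "g0", "w_1": "g1", "b_0": "g2"}  # fake gradients keyed by name

for group in param_groups:
    params_grads = defaultdict(list)  # same shape as in step() above
    for param in group["params"]:
        params_grads["params"].append((param, grads[param]))
    # carry every per-group setting except the raw parameter list
    params_grads.update({k: v for k, v in group.items() if k != "params"})
    apply_optimize(params_grads)
```
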

test/collective/fleet/hybrid_parallel_sharding_model.py

Lines changed: 0 additions & 73 deletions
@@ -19,7 +19,6 @@
 import numpy as np
 
 import paddle
-import paddle.distributed as dist
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizer,
@@ -65,72 +64,6 @@ def parallel_matmul(lm_output, logit_weights, parallel_output):
     return logits
 
 
-class SimpleMPNet(paddle.nn.Layer):
-    def __init__(
-        self,
-        vocab_size,
-        hidden_size,
-        inner_size,
-        output_size,
-        np_fc1,
-        np_fc2,
-        mp_id,
-    ):
-        super().__init__()
-
-        if mp_id == 0:
-            init_fc1_data = np_fc1[:, : (inner_size // 2)]
-            init_fc2_data = np_fc2[: (inner_size // 2), :]
-        else:
-            init_fc1_data = np_fc1[:, (inner_size // 2) :]
-            init_fc2_data = np_fc2[(inner_size // 2) :, :]
-
-        self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
-            hidden_size,
-            inner_size,
-            weight_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Assign(init_fc1_data)
-            ),
-            gather_output=False,
-            has_bias=True,
-        )
-
-        self.linear2 = fleet.meta_parallel.RowParallelLinear(
-            inner_size,
-            hidden_size,
-            weight_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Assign(init_fc2_data)
-            ),
-            input_is_parallel=True,
-            has_bias=True,
-        )
-
-        self.linear3 = paddle.nn.Linear(
-            hidden_size,
-            output_size,
-            weight_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(0.0)
-            ),
-            bias_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(0.0)
-            ),
-        )
-
-        self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
-            vocab_size,
-            hidden_size,
-            weight_attr=paddle.nn.initializer.Constant(value=0.5),
-        )
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = parallel_matmul(x, self.embedding.weight, False)
-        return x
-
-
 class SimpleDPNet(paddle.nn.Layer):
     def __init__(
         self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
@@ -240,12 +173,6 @@ def build_optimizer(self, model, strategy=None, Optimizer="adam"):
         return optimizer
 
     def build_model_optimizer(self, Optimizer="adam", amp_level=None):
-        hcg = fleet.get_hybrid_communicate_group()
-        word_size = hcg.get_model_parallel_world_size()
-        sharding_id = hcg.get_sharding_parallel_rank()
-        dp_id = hcg.get_data_parallel_rank()
-        rank_id = dist.get_rank()
-
         np_fc1 = np.random.random_sample((hidden_size, inner_size))
         np_fc2 = np.random.random_sample((inner_size, hidden_size))
 
test/collective/fleet/test_parallel_dygraph_sharding_parallel.py

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,11 @@ def test_hybrid_parallel_sharding_tensor_fusion_amp(self):
         os.environ["FLAGS_shard_split_param"] = "0"
         self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion_amp.py')
 
+    def test_hybrid_parallel_sharding_param_group(self):
+        # test shard grad reduce
+        os.environ["FLAGS_shard_split_param"] = "0"
+        self.run_mnist_2gpu('hybrid_parallel_sharding_param_group.py')
+
     def test_hybrid_parallel_sharding_state_dict(self):
         os.environ["FLAGS_shard_split_param"] = "0"
         self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
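
Note: the script registered above, hybrid_parallel_sharding_param_group.py, is not included in the hunks shown on this page. As a purely illustrative, hedged sketch of the feature it exercises (not the actual test): with this commit, a two-GPU job launched via paddle.distributed.launch can hand a param_group optimizer to fleet with sharding enabled. The degrees, model, and hyperparameters below are assumptions; ClipGradByGlobalNorm is deliberately avoided because the new code asserts against it for param_groups.

```python
# Hedged, illustrative sketch only -- not the contents of
# hybrid_parallel_sharding_param_group.py. Assumes 2 GPUs and a launch such as:
#   python -m paddle.distributed.launch --gpus "0,1" this_script.py
import paddle
from paddle.distributed import fleet


def main():
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 1,
        "pp_degree": 1,
        "sharding_degree": 2,  # sharding stage 1 across the 2 ranks
    }
    fleet.init(is_collective=True, strategy=strategy)

    model = paddle.nn.Sequential(
        paddle.nn.Linear(64, 64), paddle.nn.ReLU(), paddle.nn.Linear(64, 10)
    )
    decay_params = [p for p in model.parameters() if len(p.shape) > 1]
    other_params = [p for p in model.parameters() if len(p.shape) == 1]

    # param_groups are now accepted by the sharding wrapper; no
    # ClipGradByGlobalNorm here, which the new code rejects for groups.
    optimizer = paddle.optimizer.AdamW(
        learning_rate=1e-3,
        parameters=[
            {"params": decay_params, "weight_decay": 0.01},
            {"params": other_params, "weight_decay": 0.0},
        ],
    )

    model = fleet.distributed_model(model)
    optimizer = fleet.distributed_optimizer(optimizer)

    for _ in range(3):
        loss = model(paddle.randn([8, 64])).mean()
        loss.backward()
        optimizer.step()  # takes the per-group path added in this commit
        optimizer.clear_grad()


if __name__ == "__main__":
    main()
```
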
