 
 ######
 import os
+from collections import defaultdict
 from distutils.util import strtobool
 from functools import reduce
 
 from paddle.base.dygraph import base as imperative_base
 from paddle.base.framework import EagerParamBase
 from paddle.distributed import fleet
+from paddle.nn import ClipGradByGlobalNorm
 
 from ...utils.log_util import logger
 from ...utils.tensor_fusion_helper import (
@@ -62,21 +64,27 @@ class DygraphShardingOptimizer:
     # 4. option to choose fuse comm (needs more GPU memory) or un-fuse comm
 
     def __init__(self, optimizer, hcg):
-        logger.info("init DygraphShardingOptimizer")
-        # TODO(pangengzheng): support param_groups
-        if isinstance(optimizer._parameter_list[0], dict):
-            raise TypeError(
-                "Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter"
-            )
         if not hasattr(optimizer, '_apply_optimize') or not callable(
             optimizer._apply_optimize
         ):
             raise ValueError(
                 "the optimizer object should have an _apply_optimize function"
             )
-        # the self._parameter_list holds the whole model paramters
-        self._parameter_list = optimizer._parameter_list
-        self._origin_parameter_list = self._parameter_list
+
+        self._using_param_groups = isinstance(
+            optimizer._parameter_list[0], dict
+        )
+
+        self._parameter_list = []
+        self._param_2_group_id = {}
+        if self._using_param_groups:
+            for idx, param_group in enumerate(optimizer._param_groups):
+                for param in param_group['params']:
+                    self._param_2_group_id[id(param)] = idx
+                    self._parameter_list.append(param)
+        else:
+            self._parameter_list = optimizer._parameter_list
+
         self._inner_opt = optimizer
         self._hcg = hcg
         self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
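
For reference, the `param_groups` path that the new `_using_param_groups` flag enables assumes the wrapped inner optimizer was built from a list of dicts rather than a flat parameter list. A minimal sketch of such a setup (the model and group keys below are illustrative only, not taken from this PR):

```python
import paddle

# Hypothetical two-group setup: separate learning rates for weight and bias.
linear = paddle.nn.Linear(16, 16)
param_groups = [
    {"params": [linear.weight], "learning_rate": 1e-3},
    {"params": [linear.bias], "learning_rate": 1e-4},
]
inner_opt = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=param_groups)

# When param_groups are used, the first entry of the optimizer's parameter
# list is a dict, which is what the `_using_param_groups` check keys on.
print(isinstance(inner_opt._parameter_list[0], dict))  # True
```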
@@ -110,49 +118,67 @@ def __init__(self, optimizer, hcg):
         self._rank2params = self._partition_parameters()
         self._param2rank = self._map_param_to_rank()
 
-        if not self.tensor_fusion and not self.comm_overlap:
-            local_params = self._rank2params[self._sharding_rank]
-            self._set_inner_opt_attr('_parameter_list', local_params)
-            self._set_inner_opt_attr('_param_groups', local_params)
-        else:
-            self._tensor_fusion()
-
-            decay_params = [
-                p.name for p in self._rank2decay[self._sharding_rank]
+        if self._using_param_groups:
+            param_groups = [
+                {"params": []} for _ in range(len(optimizer._param_groups))
             ]
-            local_fused_params = self._rank2fused[self._sharding_rank]
-            apply_decay_param_fun = lambda x: x in decay_params
-
-            all_fused_params = []
-            for v in self._rank2fused.values():
-                all_fused_params += v
-            self._parameter_list = all_fused_params
-            self._param_groups = all_fused_params
+            for idx, pg in enumerate(optimizer._param_groups):
+                param_groups[idx].update(
+                    {k: v for k, v in pg.items() if k != 'params'}
+                )
+            for param in self._rank2params[self._sharding_rank]:
+                group_id = self._param_2_group_id[id(param)]
+                param_groups[group_id]['params'].append(param)
 
-            self._set_inner_opt_attr('_parameter_list', local_fused_params)
-            self._set_inner_opt_attr('_param_groups', local_fused_params)
-            if self.comm_overlap:
-                # Only set local param for check finite when comm overlap.
-                # Under comm overlap, all grads will be communicated before check_finite.
-                # Therefore, each sharding rank can get all grads' info at check_finite.
-                # Without comm overlap, all grads will be communicated after check_finite,
-                # which means each sharding rank should do check_finite to all grads.
-                self._local_parameter_list = local_fused_params
-            origin_decay_param_fun = getattr(
-                self._inner_opt, '_apply_decay_param_fun', None
+            self._set_inner_opt_attr('_param_groups', param_groups)
+            self._set_inner_opt_attr(
+                '_parameter_list', self._rank2params[self._sharding_rank]
             )
-            if origin_decay_param_fun is not None:
-                self._set_inner_opt_attr(
-                    '_apply_decay_param_fun', apply_decay_param_fun
+            self._param_groups = self._parameter_list
+        else:
+            if not self.tensor_fusion and not self.comm_overlap:
+                local_params = self._rank2params[self._sharding_rank]
+                self._set_inner_opt_attr('_parameter_list', local_params)
+                self._set_inner_opt_attr('_param_groups', local_params)
+            else:
+                self._tensor_fusion()
+
+                decay_params = [
+                    p.name for p in self._rank2decay[self._sharding_rank]
+                ]
+                local_fused_params = self._rank2fused[self._sharding_rank]
+                apply_decay_param_fun = lambda x: x in decay_params
+
+                all_fused_params = []
+                for v in self._rank2fused.values():
+                    all_fused_params += v
+                self._parameter_list = all_fused_params
+                self._param_groups = all_fused_params
+
+                self._set_inner_opt_attr('_parameter_list', local_fused_params)
+                self._set_inner_opt_attr('_param_groups', local_fused_params)
+                if self.comm_overlap:
+                    # Only set local param for check finite when comm overlap.
+                    # Under comm overlap, all grads will be communicated before check_finite.
+                    # Therefore, each sharding rank can get all grads' info at check_finite.
+                    # Without comm overlap, all grads will be communicated after check_finite,
+                    # which means each sharding rank should do check_finite to all grads.
+                    self._local_parameter_list = local_fused_params
+                origin_decay_param_fun = getattr(
+                    self._inner_opt, '_apply_decay_param_fun', None
                 )
-            # Note: during the tensor fusion for parameters, the allocator will apply for
-            # some extra GPU memory for the fused big parameters. This extra GPU memory
-            # becomes useless once the fusion is done, but Paddle's allocator won't
-            # release it; it will keep that part in its memory pool. So after tensor
-            # fusion, the 'reserved' memory will increase while the 'allocated' memory
-            # won't change. To avoid failures in other applications (such as some nvtx
-            # operations), we manually let the allocator release the cached memory here.
-            paddle.device.cuda.empty_cache()
+                if origin_decay_param_fun is not None:
+                    self._set_inner_opt_attr(
+                        '_apply_decay_param_fun', apply_decay_param_fun
+                    )
+                # Note: during the tensor fusion for parameters, the allocator will apply for
+                # some extra GPU memory for the fused big parameters. This extra GPU memory
+                # becomes useless once the fusion is done, but Paddle's allocator won't
+                # release it; it will keep that part in its memory pool. So after tensor
+                # fusion, the 'reserved' memory will increase while the 'allocated' memory
+                # won't change. To avoid failures in other applications (such as some nvtx
+                # operations), we manually let the allocator release the cached memory here.
+                paddle.device.cuda.empty_cache()
 
     def clear_grad(self, set_to_zero=True):
         """
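
The hunk above keeps only the parameters owned by the current sharding rank while preserving each group's hyperparameters: every locally owned parameter is mapped back into a copy of its original group via `_param_2_group_id`. A simplified standalone sketch of that regrouping (an assumed helper for illustration, not the PR's code):

```python
# Rebuild per-rank param_groups: copy each group's settings (learning_rate,
# weight_decay, ...) and fill in only the parameters this rank owns.
def regroup_for_rank(original_groups, local_params, param_to_group_id):
    new_groups = [
        {k: v for k, v in group.items() if k != "params"}
        for group in original_groups
    ]
    for group in new_groups:
        group["params"] = []
    for param in local_params:
        new_groups[param_to_group_id[id(param)]]["params"].append(param)
    return new_groups
```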
@@ -331,6 +357,9 @@ def minimize(
         # NOTE in dygraph mode, the only difference between step and minimize is that minimize
         # allows the user to customize the parameters for updating on each step
 
+        assert (
+            not self._using_param_groups
+        ), "minimize() is not supported when using param_groups"
         input_param_names = {param.name for param in parameters}
         parameters = list(
             filter(
@@ -356,14 +385,12 @@ def step(self):
         # otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params
         # TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
         origin_clip = self._inner_opt._grad_clip
-        target_param_list = (
-            self._origin_parameter_list
-            if (not self.tensor_fusion or not self.fuse_optimizer)
-            else self._parameter_list
-        )
-        if not isinstance(target_param_list[0], dict):
+        if (
+            not isinstance(self._parameter_list[0], dict)
+            or not self._using_param_groups
+        ):
             params_grads = []
-            for param in target_param_list:
+            for param in self._parameter_list:
                 if (
                     hasattr(param, "regularizer")
                     and param.regularizer is not None
@@ -398,6 +425,35 @@ def step(self):
             if g_shard_norm_align_dp:
                 # restore the grad clip
                 self._set_inner_opt_attr('_grad_clip', origin_clip)
+        else:
+            # optimize parameters in groups
+            for param_group in self._inner_opt._param_groups:
+                params_grads = defaultdict(lambda: [])
+
+                # TODO(shenliang03): support ClipGradByGlobalNorm in sharding when using param_groups
+                grad_clip = param_group['grad_clip']
+                assert not isinstance(
+                    grad_clip, ClipGradByGlobalNorm
+                ), "ClipGradByGlobalNorm is not supported when using param_groups in sharding"
+
+                for param in param_group['params']:
+                    if param.stop_gradient:
+                        continue
+
+                    grad_var = param._grad_ivar()
+                    if (
+                        hasattr(param, "main_grad")
+                        and param.main_grad is not None
+                    ):
+                        grad_var = param.main_grad
+
+                    params_grads['params'].append((param, grad_var))
+                params_grads.update(
+                    {k: v for k, v in param_group.items() if k != 'params'}
+                )
+                self._apply_optimize(
+                    loss=None, startup_program=None, params_grads=params_grads
+                )
 
         # sync parameters across sharding ranks
         self._sharding_sync_parameters()
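
In the new per-group branch of `step()`, each group is turned into the dict that `_apply_optimize` consumes: `'params'` maps to `(param, grad)` pairs and every other group key is carried over unchanged. A standalone sketch of that collection step (an assumed helper for illustration, not the PR's code):

```python
from collections import defaultdict

def collect_group_grads(param_group):
    # Gather (param, grad) pairs for one group, skipping frozen parameters
    # and preferring the fp32 main_grad when master gradients are enabled.
    params_grads = defaultdict(list)
    for param in param_group["params"]:
        if param.stop_gradient:
            continue
        grad = param._grad_ivar()
        if getattr(param, "main_grad", None) is not None:
            grad = param.main_grad
        params_grads["params"].append((param, grad))
    # Carry over the group's other settings (learning_rate, grad_clip, ...).
    params_grads.update(
        {k: v for k, v in param_group.items() if k != "params"}
    )
    return params_grads
```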