@@ -41,6 +41,7 @@ def __init__(self,
         self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()

         self.origin_num_embeddings = num_embeddings
+        self.is_mp = (self.world_size > 1)

         per_part_size = (
             num_embeddings + self.world_size - 1) // self.world_size
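Note on the hunk above: the sharding arithmetic is ceil-division plus one extra padding row per shard. A minimal standalone sketch of that arithmetic follows; the vocabulary size and world size are made-up values, and the contiguous ownership ranges printed here are an assumption about the index remapping, which happens in code outside the hunks shown.

```python
# Illustration only: how the constructor above splits the vocabulary.
num_embeddings = 10   # hypothetical vocabulary size
world_size = 4        # hypothetical model-parallel degree

per_part_size = (num_embeddings + world_size - 1) // world_size  # ceil(10 / 4) = 3
per_part_size += 1    # one extra row per shard, used as the local padding index

for rank in range(world_size):
    vocab_start = rank * (per_part_size - 1)
    vocab_end = min(vocab_start + (per_part_size - 1), num_embeddings)
    print(f"rank {rank}: global ids [{vocab_start}, {vocab_end}), "
          f"local table has {per_part_size} rows (last row = padding)")
```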
@@ -50,16 +51,36 @@ def __init__(self,
         per_part_size += 1  # make the last row as the padding index
         self.per_part_size = per_part_size

-        self.embedding = paddle.nn.Embedding(
-            per_part_size,
-            embedding_dim,
-            padding_idx=per_part_size - 1,
-            sparse=False,
-            weight_attr=weight_attr,
-            name=name)
-        self.embedding.weight.is_distributed = True
+        self._dtype = self._helper.get_default_dtype()
+        self._size = [per_part_size, embedding_dim]
+        self._weight_attr = weight_attr
+        self._name = name
+
+        if self.is_mp:
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    attr=self._weight_attr,
+                    shape=self._size,
+                    dtype=self._dtype,
+                    is_bias=False)
+            self.weight[per_part_size - 1] = 0.0
+            self.weight.is_distributed = True
+        else:
+            self.weight = self.create_parameter(
+                attr=self._weight_attr,
+                shape=[num_embeddings, embedding_dim],
+                dtype=self._dtype,
+                is_bias=False)

     def forward(self, x):
+        if not self.is_mp:
+            return F.embedding(
+                x,
+                weight=self.weight,
+                padding_idx=None,
+                sparse=False,
+                name=self._name)
+
         origin_input_shape = x.shape
         if len(origin_input_shape) == 2:
             x = paddle.unsqueeze(x, axis=-1)
@@ -72,13 +93,18 @@ def forward(self, x):
         if len(origin_input_shape) == 2:
             x_shard = paddle.squeeze(x_shard, axis=-1)

-        emb_out = self.embedding(x_shard)
-        if self.world_size > 1:
-            emb_out = paddle.distributed.collective._mp_allreduce(
-                emb_out,
-                group=self.model_parallel_group,
-                use_calc_stream=True,
-                use_model_parallel=True)
+        emb_out = F.embedding(
+            x_shard,
+            weight=self.weight,
+            padding_idx=self.per_part_size - 1,
+            sparse=False,
+            name=self._name)
+
+        emb_out = paddle.distributed.collective._mp_allreduce(
+            emb_out,
+            group=self.model_parallel_group,
+            use_calc_stream=True,
+            use_model_parallel=True)
         return emb_out


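The new forward path above relies on two ingredients: each rank looks up only the ids it owns (everything else is remapped to the zero-initialized padding row), and an allreduce-sum across ranks reassembles the full result. Below is a minimal NumPy simulation of that pattern with the ranks run in a loop; it is an illustration of the idea, not Paddle code, and the contiguous id-to-shard mapping is an assumption about the remapping step that falls between the hunks shown here.

```python
import numpy as np

vocab, dim, world_size = 10, 4, 2                # made-up sizes
shard = (vocab + world_size - 1) // world_size   # rows of real vocab per rank
tables = [np.random.rand(shard + 1, dim) for _ in range(world_size)]
for t in tables:
    t[-1] = 0.0                                  # last row is the padding row, kept at zero

ids = np.array([0, 3, 7, 9])                     # global token ids

partials = []
for rank, table in enumerate(tables):
    local = ids - rank * shard                   # shift into this rank's local range
    oob = (local < 0) | (local >= shard)         # ids owned by other ranks
    local = np.where(oob, shard, local)          # send them to the zero padding row
    partials.append(table[local])

emb = np.sum(partials, axis=0)                   # stands in for the mp allreduce(sum)
full = np.concatenate([t[:-1] for t in tables])[:vocab]  # the logical full table
assert np.allclose(emb, full[ids])
```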
@@ -96,8 +122,9 @@ def __init__(self,
         )
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
         )
+        self._name = name
+        self.is_mp = (self.world_size > 1)

-        self.name = name
         self.gather_output = gather_output
         assert out_features % self.world_size == 0, (
             "Number of column of the weight for linear ({}) must be"
@@ -108,29 +135,45 @@ def __init__(self,
         self._weight_attr = weight_attr
         self._dtype = self._helper.get_default_dtype()

-        self.weight = self.create_parameter(
-            shape=[in_features, self.output_size_per_partition],
-            attr=self._weight_attr,
-            dtype=self._dtype)
+        if self.is_mp:
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[in_features, self.output_size_per_partition],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False)
+        else:
+            self.weight = self.create_parameter(
+                shape=[in_features, self.output_size_per_partition],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False)
+
         self.weight.is_distributed = True

         if has_bias:
             # initialize bias to zero like Megatron
             self.bias = self.create_parameter(
                 shape=[self.output_size_per_partition],
                 attr=paddle.nn.initializer.Constant(value=0.0),
-                dtype=self._dtype)
+                dtype=self._dtype,
+                is_bias=True)
             self.bias.is_distributed = True
         else:
             self.bias = None

     def forward(self, x):
         # use inner api to process identity
-        input_parallel = paddle.distributed.collective._c_identity(
-            x, group=self.model_parallel_group)
+        if self.is_mp:
+            input_parallel = paddle.distributed.collective._c_identity(
+                x, group=self.model_parallel_group)
+        else:
+            input_parallel = x
+
         output_parallel = F.linear(
-            input_parallel, self.weight, self.bias, name=self.name)
-        if self.gather_output:
+            input_parallel, self.weight, self.bias, name=self._name)
+
+        if self.gather_output and self.is_mp:
             output = paddle.distributed.collective._c_concat(
                 output_parallel,
                 nranks=self.world_size,
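The column-parallel pattern in the hunk above is easy to verify numerically: splitting the weight by columns and multiplying each shard locally yields column slices of the full output, so concatenating them (the role of `_c_concat` when `gather_output` is set) reproduces the serial result. A small NumPy sketch with simulated ranks follows; the shapes are made up and the bias is omitted for brevity.

```python
import numpy as np

batch, in_features, out_features, world_size = 3, 8, 6, 2   # made-up shapes
x = np.random.rand(batch, in_features)
w = np.random.rand(in_features, out_features)

# Each rank holds a column slice of the weight, as in the constructor above.
cols = out_features // world_size
shards = [w[:, r * cols:(r + 1) * cols] for r in range(world_size)]

# Local matmul per rank, then concatenation along the feature axis
# (what _c_concat does across ranks when gather_output=True).
partials = [x @ w_r for w_r in shards]
y = np.concatenate(partials, axis=-1)

assert np.allclose(y, x @ w)
```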
@@ -155,37 +198,49 @@ def __init__(self,
         self.input_is_parallel = input_is_parallel
         self._weight_attr = weight_attr
         self._dtype = self._helper.get_default_dtype()
-        self.name = name
+        self._name = name

         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
         )
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
         )
         self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()

+        self.is_mp = (self.world_size > 1)
         assert in_features % self.world_size == 0, (
             "Number of row of the weight for linear ({}) must be"
             " divisible by model parallel size ({})".format(in_features,
                                                             self.world_size))

         self.input_size_per_partition = in_features // self.world_size

-        self.weight = self.create_parameter(
-            shape=[self.input_size_per_partition, self.out_features],
-            attr=self._weight_attr,
-            dtype=self._dtype)
+        if self.is_mp:
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[self.input_size_per_partition, self.out_features],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False)
+        else:
+            self.weight = self.create_parameter(
+                shape=[self.input_size_per_partition, self.out_features],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False)
+
         self.weight.is_distributed = True

         if has_bias:
             self.bias = self.create_parameter(
                 shape=[self.out_features],
                 attr=paddle.nn.initializer.Constant(value=0.0),
-                dtype=self._dtype)
+                dtype=self._dtype,
+                is_bias=True)
         else:
             self.bias = None

     def forward(self, x):
-        if self.input_is_parallel:
+        if self.input_is_parallel or (not self.is_mp):
             input_parallel = x
         else:
             # split last dim
@@ -195,12 +250,16 @@ def forward(self, x):
                 nranks=self.world_size,
                 group=self.model_parallel_group)

-        output_parallel = F.linear(input_parallel, self.weight, name=self.name)
-        output_ = paddle.distributed.collective._mp_allreduce(
-            output_parallel,
-            group=self.model_parallel_group,
-            use_calc_stream=True,
-            use_model_parallel=True)
+        output_parallel = F.linear(input_parallel, self.weight, name=self._name)
+
+        if self.is_mp:
+            output_ = paddle.distributed.collective._mp_allreduce(
+                output_parallel,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True)
+        else:
+            output_ = output_parallel

         output = output_ + self.bias if self.bias is not None else output_
         return output
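RowParallelLinear is the dual construction: the input's last dimension and the weight's rows are split the same way, each rank computes a partial product, and the allreduce-sum in the hunk above adds the partial products back together before the bias is applied once. A minimal NumPy sketch of that identity follows (made-up shapes, ranks simulated in-process, bias omitted).

```python
import numpy as np

batch, in_features, out_features, world_size = 3, 8, 6, 2   # made-up shapes
x = np.random.rand(batch, in_features)
w = np.random.rand(in_features, out_features)

rows = in_features // world_size
x_shards = [x[:, r * rows:(r + 1) * rows] for r in range(world_size)]   # like _c_split
w_shards = [w[r * rows:(r + 1) * rows, :] for r in range(world_size)]   # row slices

# Each rank multiplies its input slice by its weight slice; summing the
# partial results plays the role of _mp_allreduce in the forward above.
y = sum(x_r @ w_r for x_r, w_r in zip(x_shards, w_shards))

assert np.allclose(y, x @ w)
```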