@@ -173,91 +173,6 @@ def fused_linear_cross_entropy_forward_megatron_chunked(
 
     return loss, None, grad_input, grad_weight, grad_bias
 
-def fused_linear_cross_entropy_forward_megatron(
-    _input,
-    weight,
-    target,
-    bias=None,
-    reduction="none",
-):
-    device = _input.device
-    BT, H = _input.shape
-    V = weight.shape[0]
-
-    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-    grad_input = torch.zeros_like(_input, device=device)
-    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
-    # we use fp32 for loss accumulator
-    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
-
-    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
-    rank = get_tensor_model_parallel_rank()
-    world_size = get_tensor_model_parallel_world_size()
-    vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size(V, rank, world_size)
-
-    target_mask = (target < vocab_start) | (target >= vocab_end)
-    adjusted_target = target.clone() - vocab_start  # relative id
-    adjusted_target[target_mask] = 0
-    adjusted_target_1d = adjusted_target.view(-1)
-
-    # input
-    # when doing matmul, use the original precision
-    logits = (_input @ weight.t()).float()  # chunk_size x V
-    if bias is not None:
-        logits = logits + bias
-
-    # # ensure _input and target are contiguous
-    # logits_chunk = logits_chunk.contiguous()  # [chunk_size, vocab_size]
-    # target_chunk = target_chunk.contiguous()  # [chunk_size]
-
-    max_logits = torch.max(logits, dim=-1)[0]
-    torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group(), async_op=False)
-    logits = logits - max_logits.unsqueeze(-1)
-
-    sum_exp_logits = torch.sum(torch.exp(logits), dim=-1)
-    torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=False)
-
-
-    predicted_logits = logits[torch.arange(BT, device=logits.device), adjusted_target_1d]
-    predicted_logits[target_mask] = 0.0
-    handle_predicted_logits = torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=True)
-
-    # Compute gradient
-    grad_logits = torch.exp(logits).div_(sum_exp_logits.unsqueeze(-1))
-    grad_logits[torch.arange(BT, device=grad_logits.device), adjusted_target_1d] -= 1.0 - target_mask.float()  # chunk_size x V
-    grad_input = grad_logits.to(dtype=torch.half) @ weight
-    torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=False)
-
-    if grad_weight is not None:
-        torch.addmm(
-            input=grad_weight,
-            mat1=grad_logits.t().to(
-                _input.dtype
-            ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
-            mat2=_input,
-            out=grad_weight,
-            alpha=1.0,
-            beta=1.0,
-        )
-    if bias is not None:
-        torch.add(
-            input=grad_bias,
-            other=grad_logits.sum(dim=0),
-            out=grad_bias,
-            alpha=1.0,
-        )
-    handle_predicted_logits.wait()
-    loss_chunk = torch.log(sum_exp_logits) - predicted_logits
-    loss_1d = loss_chunk
-
-    if reduction == "none":
-        loss = loss_1d
-    else:
-        loss = torch.sum(loss_1d)
-
-    return loss, None, grad_input, grad_weight, grad_bias
-
-
 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):