@@ -497,6 +497,65 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
     return (d_input, d_weight, d_bias)
 
 
+@register_decomposition(aten.native_batch_norm_backward)
+def native_batch_norm_backward(grad_out: Tensor, input: Tensor, weight: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_invstd: Optional[Tensor], train: bool, eps: float, output_mask: List[bool]) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    # Number of elements reduced over per channel, e.g. N * H * W for NCHW input.
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+    else:
+        assert running_mean is not None and running_var is not None, "when train=False, running_mean and running_var are required"
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    # Shape that broadcasts the per-channel statistics against the input.
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    grad_scale = None
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+    grad_input = None
+    if train:
+        # The saved batch statistics depend on the input, so remove the
+        # projection onto (input - mean) and the gradient of the mean.
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
+    grad_weight = None
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+
+    grad_bias = None
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    return (grad_input, grad_weight, grad_bias)
+
+
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
     return torch.clamp(self, min=min)
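
For reference, writing g for grad_out, x_hat = (x - mu) * invstd, and taking sums and means E[.] over every axis except the channel axis, the training-mode branch of the new decomposition computes the standard batch norm gradients:

\frac{\partial L}{\partial x} = \gamma \cdot \mathrm{invstd} \cdot \bigl( g - \mathbb{E}[g] - \hat{x}\,\mathbb{E}[g\hat{x}] \bigr),
\qquad
\frac{\partial L}{\partial \gamma} = \sum g\hat{x},
\qquad
\frac{\partial L}{\partial \beta} = \sum g

In eval mode the statistics are constants with respect to the input, so grad_input reduces to gamma * invstd * g, which is the else branch above.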
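As a quick sanity check (a sketch, not part of the commit), the decomposition can be compared against autograd through torch.nn.functional.batch_norm, whose backward dispatches to the aten kernel being decomposed here. The snippet assumes the native_batch_norm_backward defined above, along with the file's prod helper, is in scope:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(2, 3, 4, 5, requires_grad=True)
weight = torch.randn(3, requires_grad=True)
bias = torch.randn(3, requires_grad=True)

# Reference gradients via autograd in training mode (batch statistics are used).
out = F.batch_norm(x, None, None, weight, bias, training=True, eps=eps)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# Recompute the statistics the kernel would have saved: mean and biased
# variance over every axis except the channel axis.
dims = [0, 2, 3]
mean = x.detach().mean(dims)
invstd = torch.rsqrt(x.detach().var(dims, unbiased=False) + eps)

d_input, d_weight, d_bias = native_batch_norm_backward(
    grad_out, x.detach(), weight.detach(), None, None,
    mean, invstd, True, eps, [True, True, True],
)
torch.testing.assert_close(d_input, x.grad)
torch.testing.assert_close(d_weight, weight.grad)
torch.testing.assert_close(d_bias, bias.grad)

torch.testing.assert_close applies dtype-based default tolerances, which absorbs the small elementwise differences between this decomposition and the native kernel.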