@@ -64,6 +64,29 @@ def forward(ctx, i, w):
     @staticmethod
     def backward(ctx, grad_output):
         (w,) = ctx.saved_tensors
67+ """
68+ A[m,k] @ B[k,n] -> O[m,n]
69+ grad_o[m,n] @ B.t()[n,k] -> grad_a[m,k]
70+ looks right..
71+ getting
72+ [rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 67, in backward
73+ [rank4]:[rank4]: grad_input = grad_output.mm(w.t())
74+ [rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 88, in split_mm
75+ [rank4]:[rank4]: return MmPassThrough.apply(i1, w1)
76+ [rank4]:[rank4]: File "/data/users/whc/pytorch/torch/autograd/function.py", line 583, in apply
77+ [rank4]:[rank4]: return super().apply(*args, **kwargs) # type: ignore[misc]
78+ [rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 74, in forward
79+ [rank4]:[rank4]: return torch.mm(x, y)
80+ [rank4]:[rank4]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x2816 and 2048x2816)
81+
82+ [rank4]:[rank4]: RuntimeError:
83+ [rank4]:[rank4]: Failed to run stage backward:
84+ [rank4]:[rank4]: Stage output: ('Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)',)
85+ [rank4]:[rank4]: Output gradient: ('Tensor(torch.Size([1, 4096, 2048]), grad=False, dtype=torch.bfloat16)',)
86+ [rank4]:[rank4]: Input: ['Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)']
87+ [rank4]:[rank4]:
88+ """
+        logger.error(f"MmSeparateInputGrad backward: {grad_output.shape=}, {w.t().shape=}")
         grad_input = grad_output.mm(w.t())
         return grad_input, None
 
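
A quick sanity check on the shape algebra in the docstring above: the standalone sketch below (illustration only, not part of this patch; m, k, n are arbitrary sizes) verifies the textbook mm backward formulas against autograd.

import torch

m, k, n = 4, 3, 5
a = torch.randn(m, k, requires_grad=True)   # A[m, k]
b = torch.randn(k, n, requires_grad=True)   # B[k, n]
o = a.mm(b)                                 # O[m, n]
grad_o = torch.randn(m, n)                  # upstream gradient, same shape as O
o.backward(grad_o)

# grad_a = grad_o @ B.t()  -> [m, n] @ [n, k] = [m, k]
# grad_b = A.t() @ grad_o  -> [k, m] @ [m, n] = [k, n]
assert torch.allclose(a.grad, grad_o.mm(b.t()))
assert torch.allclose(b.grad, a.t().mm(grad_o))

The RuntimeError in the log itself is the generic inner-dimension mismatch: a 4096x2816 matrix cannot be multiplied by a 2048x2816 one, since 2816 != 2048.
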
@@ -75,13 +98,14 @@ def forward(ctx, x, y):
 
     @staticmethod
     def backward(ctx, gO):
+        # TODO(whc): Claude first wrote it this way and later tried returning None, None; not sure which is correct.
         return gO, gO
 
 def split_mm(i, w):
-    print("split mul")
     # Apply the pass-through node. y is passed to this node so that it can be
     # saved for backward, but detach because we don't want to actually build
     # this edge of the graph
+    logger.error(f"split_mm forward: {i.shape=}, {w.shape=}")
     w1 = MmSeparateWeightGrad.apply(i.detach(), w)
     i1 = MmSeparateInputGrad.apply(i, w.detach())
     return MmPassThrough.apply(i1, w1)
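
On the TODO above (return gO, gO vs None, None): the sketch below is only a generic illustration of the torch.autograd.Function contract, not a statement of what this patch should do. backward() returns one value per forward() input, and returning None for an input stops gradient flow into that branch; the PassBoth class and the toy addition are made up for illustration.

import torch

class PassBoth(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        return x + y

    @staticmethod
    def backward(ctx, gO):
        # One return value per forward() input. For addition, each input's
        # gradient is gO itself; returning (gO, None) instead would leave
        # y.grad unset after backward().
        return gO, gO

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
PassBoth.apply(x, y).sum().backward()
print(x.grad is not None, y.grad is not None)  # True True

If the intent is for both MmSeparateInputGrad and MmSeparateWeightGrad to see the output gradient in their backwards, the pass-through node presumably has to forward gO to both of its inputs rather than returning None, None.
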
@@ -138,7 +162,6 @@ def backward(ctx, gO):
         return gO, gO, gO, None, None
 
 def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
-    print("split addmm")
     mat2_1 = AddmmSeparateMat2Grad.apply(mat1.detach(), mat2, alpha)
     mat1_1 = AddmmSeparateMat1Grad.apply(mat1, mat2.detach(), alpha)
     bias_1 = AddmmSeparateBiasGrad.apply(bias, beta)
@@ -197,7 +220,6 @@ def backward(ctx, gO):
         return gO, None, gO, None
 
 def split_rms_norm(input, normalized_shape, weight=None, eps=None):
-    print("split rms_norm")
     weight_1 = RmsNormSeparateWeightGrad.apply(
         input.detach(), normalized_shape, weight, eps
     )
@@ -255,7 +277,6 @@ def backward(ctx, gO):
         return gO, gO, None, None, None
 
 def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
-    print("split grouped_mm")
     mat2_1 = GroupedMmSeparateMat2Grad.apply(
         input.detach(), mat2, offs, bias, out_dtype
     )