Commit 0dfc140

wconstab authored and H-Huang committed
debugging mm backwards shape error
ghstack-source-id: 7ee5b8a
Pull Request resolved: pytorch#2035
1 parent ce6150e · commit 0dfc140

File tree

2 files changed, +26 -5 lines


torchtitan/distributed/pipeline_parallel.py

Lines changed: 25 additions & 4 deletions
@@ -64,6 +64,29 @@ def forward(ctx, i, w):
     @staticmethod
     def backward(ctx, grad_output):
         (w,) = ctx.saved_tensors
+        """
+        A[m,k] @ B[k,n] -> O[m,n]
+        grad_o[m,n] @ B.t()[n,k] -> grad_a[m,k]
+        looks right..
+        getting
+        [rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 67, in backward
+        [rank4]:[rank4]: grad_input = grad_output.mm(w.t())
+        [rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 88, in split_mm
+        [rank4]:[rank4]: return MmPassThrough.apply(i1, w1)
+        [rank4]:[rank4]: File "/data/users/whc/pytorch/torch/autograd/function.py", line 583, in apply
+        [rank4]:[rank4]: return super().apply(*args, **kwargs)  # type: ignore[misc]
+        [rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 74, in forward
+        [rank4]:[rank4]: return torch.mm(x, y)
+        [rank4]:[rank4]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x2816 and 2048x2816)
+
+        [rank4]:[rank4]: RuntimeError:
+        [rank4]:[rank4]: Failed to run stage backward:
+        [rank4]:[rank4]: Stage output: ('Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)',)
+        [rank4]:[rank4]: Output gradient: ('Tensor(torch.Size([1, 4096, 2048]), grad=False, dtype=torch.bfloat16)',)
+        [rank4]:[rank4]: Input: ['Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)']
+        [rank4]:[rank4]:
+        """
+        logger.error(f"MmSeparateInputGrad backward: {grad_output.shape=}, {w.t().shape=}")
         grad_input = grad_output.mm(w.t())
         return grad_input, None
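
Note on the docstring above: the quoted backward rule is the standard mm gradient and checks out in isolation; a minimal standalone verification (illustrative only, not part of the commit) follows. Reading the quoted error shapes the same way: w.t() = [2048, 2816] implies w = [2816, 2048], so an output-shaped grad_output for this mm would be [4096, 2048]; the [4096, 2816] tensor that actually arrived matches the mm's input instead, which is one possible reading of the mismatch.

import torch

# Standalone check of A[m,k] @ B[k,n] -> O[m,n] with
# grad_A[m,k] = grad_O[m,n] @ B.t()[n,k]. Small dims stand in for the
# shapes in the error (m=4096, k=2816, n=2048 under the reading above).
m, k, n = 4, 3, 5
A = torch.randn(m, k, requires_grad=True)
B = torch.randn(k, n)

out = A.mm(B)                    # [m, n]
grad_out = torch.ones_like(out)  # stand-in for the incoming gradient
out.backward(grad_out)

manual = grad_out.mm(B.t())      # [m, n] @ [n, k] -> [m, k]
assert torch.allclose(A.grad, manual)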

@@ -75,13 +98,14 @@ def forward(ctx, x, y):
 
     @staticmethod
     def backward(ctx, gO):
+        # TODO(whc) - claude first wrote it this way and later tried to return None, None, i'm not sure which is correct
         return gO, gO
 
 def split_mm(i, w):
-    print("split mul")
     # Apply the pass-through node. y is passed to this node so that it can be
     # saved for backward, but detach because we don't want to actually build
     # this edge of the graph
+    logger.error(f"split_mm forward: {i.shape=}, {w.shape=}")
     w1 = MmSeparateWeightGrad.apply(i.detach(), w)
     i1 = MmSeparateInputGrad.apply(i, w.detach())
     return MmPassThrough.apply(i1, w1)
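
Note on the TODO above: whichever return value turns out to be right for torchtitan, the constraint it is wrestling with is stock PyTorch behavior and can be demonstrated standalone (demo code, not torchtitan code): a custom Function's backward must return one value per forward input, each either None (no gradient flows to that input) or a tensor shape-compatible with that input. Passing gO through unchanged is therefore only shape-legal when the mm output matches both inputs, i.e. for square matrices.

import torch

class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        return torch.mm(x, y)

    @staticmethod
    def backward(ctx, gO):
        # `return None, None` would instead cut gradient flow to x and y.
        return gO, gO

# Square case: gO is [3, 3], matching both inputs, so (gO, gO) is accepted.
# (x.grad becomes gO itself here, not the true mm gradient; in the scheme
# above, the real gradients come from the separate-grad nodes.)
x = torch.randn(3, 3, requires_grad=True)
y = torch.randn(3, 3, requires_grad=True)
PassThrough.apply(x, y).sum().backward()

# Non-square case: gO is [2, 4] but x is [2, 3] and y is [3, 4]; autograd
# rejects this with a "Function returned an invalid gradient" RuntimeError.
x = torch.randn(2, 3, requires_grad=True)
y = torch.randn(3, 4, requires_grad=True)
try:
    PassThrough.apply(x, y).sum().backward()
except RuntimeError as e:
    print(e)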
@@ -138,7 +162,6 @@ def backward(ctx, gO):
         return gO, gO, gO, None, None
 
 def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
-    print("split addmm")
     mat2_1 = AddmmSeparateMat2Grad.apply(mat1.detach(), mat2, alpha)
     mat1_1 = AddmmSeparateMat1Grad.apply(mat1, mat2.detach(), alpha)
     bias_1 = AddmmSeparateBiasGrad.apply(bias, beta)
@@ -197,7 +220,6 @@ def backward(ctx, gO):
         return gO, None, gO, None
 
 def split_rms_norm(input, normalized_shape, weight=None, eps=None):
-    print("split rms_norm")
     weight_1 = RmsNormSeparateWeightGrad.apply(
         input.detach(), normalized_shape, weight, eps
     )
@@ -255,7 +277,6 @@ def backward(ctx, gO):
         return gO, gO, None, None, None
 
 def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
-    print("split grouped_mm")
     mat2_1 = GroupedMmSeparateMat2Grad.apply(
         input.detach(), mat2, offs, bias, out_dtype
     )

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]"
 
 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "none" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
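
Note on the config change above: the commit message doesn't say why activation checkpointing is switched off; one plausible motivation (an assumption, not stated in the source) is that checkpointing re-runs forward ops, including mm, during the backward pass, adding another source of mm calls while a backward shape error is being chased. A standalone illustration of that recompute behavior:

import torch
from torch.utils.checkpoint import checkpoint

calls = []

def f(x, w):
    calls.append("mm")  # records each time this forward actually runs
    return torch.mm(x, w)

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 5, requires_grad=True)

out = checkpoint(f, x, w, use_reentrant=False)
out.sum().backward()
print(calls)  # ['mm', 'mm']: the checkpointed forward ran again in backward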
