Skip to content

Commit a5d87f5

Browse files
authored
[BugFix] Fix FusedLinearWithGradAdd usage (#8178)
1 parent a33a4f3 commit a5d87f5

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

llm/llama/fused_layers.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,16 +58,18 @@ def backward(ctx, y_grad):
5858

5959
if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
6060
weight.main_grad, bias.main_grad = _C_ops.fused_linear_param_grad_add(
61-
x, y_grad, weight.main_grad, bias.main_grad, True
61+
x, y_grad, weight.main_grad, bias.main_grad, True, True
6262
)
6363
return x_grad, None, None
6464
else:
6565
if weight.grad is not None:
6666
assert bias.grad is not None
67-
weight.grad, bias.grad = _C_ops.fused_linear_param_grad_add(x, y_grad, weight.grad, bias.grad, False)
67+
weight.grad, bias.grad = _C_ops.fused_linear_param_grad_add(
68+
x, y_grad, weight.grad, bias.grad, False, True
69+
)
6870
return x_grad, None, None
6971
else:
70-
weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False)
72+
weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False, True)
7173
return x_grad, weight_grad, bias_grad
7274

7375

0 commit comments

Comments (0)