We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a33a4f3 · commit a5d87f5 — Copy full SHA for a5d87f5
1 file changed
llm/llama/fused_layers.py
@@ -58,16 +58,18 @@ def backward(ctx, y_grad):
58
59
if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
60
weight.main_grad, bias.main_grad = _C_ops.fused_linear_param_grad_add(
61
- x, y_grad, weight.main_grad, bias.main_grad, True
+ x, y_grad, weight.main_grad, bias.main_grad, True, True
62
)
63
return x_grad, None, None
64
else:
65
if weight.grad is not None:
66
assert bias.grad is not None
67
- weight.grad, bias.grad = _C_ops.fused_linear_param_grad_add(x, y_grad, weight.grad, bias.grad, False)
+ weight.grad, bias.grad = _C_ops.fused_linear_param_grad_add(
68
+ x, y_grad, weight.grad, bias.grad, False, True
69
+ )
70
71
- weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False)
72
+ weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False, True)
73
return x_grad, weight_grad, bias_grad
74
75
0 commit comments