
Commit 77da910

fused linear grad add bug fix and perf optim (#56094)
* skip CopyOrAdd when tmp grad is None (#55679)
* Optim fused linear grad add (#55927)
1 parent 9b317b2 commit 77da910

5 files changed: +36, -9 lines changed

paddle/fluid/eager/accumulation/accumulation_node.cc

Lines changed: 4 additions & 1 deletion
@@ -124,7 +124,10 @@ GradNodeAccumulation::operator()(
 
   if (!weak_grad_.expired() && !is_new_grad) {
     auto grad = weak_grad_.lock();
-    CopyOrAddTensor(grad.get(), grad_out, is_fake_empty_);
+    if (grad_out.defined() && grad_out.initialized()) {
+      CopyOrAddTensor(grad.get(), grad_out, is_fake_empty_);
+    }
+    // else { do nothing since there is no valid value in grad out tensor }
     is_fake_empty_ = false;
   }
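
The guarded accumulation above only calls CopyOrAddTensor when the incoming gradient tensor actually holds data; an undefined or uninitialized grad_out is now a no-op instead of an error. A minimal, self-contained C++ sketch of the same skip-if-empty accumulation pattern (illustrative stand-ins only, not Paddle's Tensor API):

#include <iostream>
#include <optional>

// Stand-in for a gradient tensor: an engaged optional plays the role of
// "defined() && initialized()".
using Grad = std::optional<float>;

// Accumulate `incoming` into `stored`, skipping the copy/add entirely when
// the incoming gradient carries no value -- mirroring the guard added above.
void CopyOrAdd(Grad &stored, const Grad &incoming) {
  if (!incoming.has_value()) {
    return;  // nothing valid to accumulate; leave the stored grad untouched
  }
  stored = stored.has_value() ? *stored + *incoming : *incoming;
}

int main() {
  Grad stored;                      // empty accumulator
  CopyOrAdd(stored, Grad{});        // undefined incoming grad: no-op
  CopyOrAdd(stored, Grad{1.5f});    // first defined grad: copied in
  CopyOrAdd(stored, Grad{2.0f});    // second defined grad: added
  std::cout << *stored << "\n";     // prints 3.5
  return 0;
}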

paddle/phi/api/yaml/fused_ops.yaml

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@
   support_dygraph_mode : true
 
 - op : fused_linear_param_grad_add
-  args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
+  args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true, bool has_bias = true)
   output : Tensor(dweight_out), Tensor(dbias_out)
   infer_meta:
     func : FusedLinearParamGradAddInferMeta

paddle/phi/infermeta/multiary.cc

Lines changed: 2 additions & 1 deletion
@@ -1259,6 +1259,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
                                       const MetaTensor& dweight,
                                       const MetaTensor& dbias,
                                       bool multi_precision,
+                                      bool has_bias,
                                       MetaTensor* dweight_out,
                                       MetaTensor* dbias_out) {
   const auto dtype = dout.dtype();
@@ -1302,7 +1303,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
           ? DataType::FLOAT32
           : dtype;
 
-  if (dbias_out) {
+  if (has_bias && dbias_out) {
     dbias_out->set_dims({weight_dims[1]});
     dbias_out->set_dtype(multi_precision ? mp_dtype : dtype);
   }
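
For context on the shapes the infer-meta sets: with x of shape [M, K], dout of shape [M, N], and weight of shape [K, N], the bias gradient is the column-wise sum of dout, a length-N vector, which is why dbias_out gets dims {weight_dims[1]}. A small illustrative C++ sketch (plain vectors, not Paddle tensors):

#include <cstdint>
#include <iostream>
#include <vector>

// Bias gradient for y = x * W + b with x: [M, K], W: [K, N], dout: [M, N]:
// the column-wise sum of dout, i.e. a length-N vector. This is why the
// infer-meta above sets dbias_out's dims to {weight_dims[1]}.
std::vector<float> BiasGrad(const std::vector<float> &dout, int64_t M, int64_t N) {
  std::vector<float> dbias(N, 0.0f);
  for (int64_t i = 0; i < M; ++i) {
    for (int64_t j = 0; j < N; ++j) {
      dbias[j] += dout[i * N + j];
    }
  }
  return dbias;  // shape [N]
}

int main() {
  // dout is 2 x 3; the bias gradient sums its two rows.
  const std::vector<float> dout = {1, 2, 3,
                                   4, 5, 6};
  for (float v : BiasGrad(dout, 2, 3)) std::cout << v << " ";  // prints: 5 7 9
  std::cout << "\n";
  return 0;
}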

paddle/phi/infermeta/multiary.h

Lines changed: 1 addition & 0 deletions
@@ -265,6 +265,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
                                       const MetaTensor& dweight,
                                       const MetaTensor& dbias,
                                       bool multi_precision,
+                                      bool has_bias,
                                       MetaTensor* dweight_out,
                                       MetaTensor* dbias_out);

paddle/phi/kernels/fusion/gpu/fused_linear_param_grad_add_kernel.cu

Lines changed: 28 additions & 6 deletions
@@ -40,6 +40,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
                                  int64_t K,
                                  int64_t N,
                                  bool use_addto,
+                                 bool has_bias,
                                  DenseTensor *dweight_out,
                                  DenseTensor *dbias_out) {
   constexpr bool kIsMultiPrecision = !std::is_same<T, MT>::value;
@@ -65,7 +66,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
         use_addto);
   }
 
-  if (dbias_out == nullptr) return;
+  if (!has_bias) return;
 
   if (!fuse_bias_grad) {
     auto dout_copy = dout;
@@ -126,6 +127,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
                             const paddle::optional<DenseTensor> &dweight,
                             const paddle::optional<DenseTensor> &dbias,
                             bool multi_precision,
+                            bool has_bias,
                             DenseTensor *dweight_out,
                             DenseTensor *dbias_out) {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -159,7 +161,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
     multi_precision = false;
   }
 
-  if (dbias_out) {
+  if (has_bias && dbias_out) {
     ctx.template Alloc<T>(dbias_out);
   }
 
@@ -176,18 +178,37 @@ void FusedLinearParamGradAdd(const Context &ctx,
     PrintMeta<kLogLevel>(dweight_out, "dweight_out");
     PrintMeta<kLogLevel>(dbias_out, "dbias_out");
     VLOG(kLogLevel) << "multi_precision = " << multi_precision;
+    VLOG(kLogLevel) << "has_bias = " << has_bias;
     VLOG(kLogLevel) << "use_addto = " << use_addto;
     VLOG(kLogLevel) << "M = " << M;
     VLOG(kLogLevel) << "N = " << N;
     VLOG(kLogLevel) << "K = " << K;
   }
 
   if (multi_precision) {
-    FusedLinearParamGradAddImpl<T, MT, Context>(
-        ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
+    FusedLinearParamGradAddImpl<T, MT, Context>(ctx,
+                                                x,
+                                                dout,
+                                                dbias,
+                                                M,
+                                                K,
+                                                N,
+                                                use_addto,
+                                                has_bias,
+                                                dweight_out,
+                                                dbias_out);
   } else {
-    FusedLinearParamGradAddImpl<T, T, Context>(
-        ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
+    FusedLinearParamGradAddImpl<T, T, Context>(ctx,
+                                               x,
+                                               dout,
+                                               dbias,
+                                               M,
+                                               K,
+                                               N,
+                                               use_addto,
+                                               has_bias,
+                                               dweight_out,
+                                               dbias_out);
   }
 }
 
@@ -199,6 +220,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
                             const paddle::optional<DenseTensor> &dweight,
                             const paddle::optional<DenseTensor> &dbias,
                             bool multi_precision,
+                            bool has_bias,
                             DenseTensor *dweight_out,
                             DenseTensor *dbias_out) {
   PADDLE_THROW(phi::errors::Unimplemented(
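
The new has_bias flag is the perf-relevant piece: the kernel previously decided whether to produce a bias gradient from dbias_out being non-null, whereas now a caller whose linear layer has no bias skips the bias-grad path (the extra allocation and reduction) explicitly. A rough, self-contained C++ sketch of that control flow (illustrative names, not the actual kernel):

#include <iostream>

// Rough sketch of the dispatch after this change: the weight-gradient GEMM
// always runs, while the bias-gradient work is skipped entirely when has_bias
// is false, saving an output allocation and a reduction per call.
struct Workload {
  int gemm_launches = 0;
  int bias_reduce_launches = 0;
};

void FusedLinearParamGradAddSketch(bool has_bias, Workload *w) {
  ++w->gemm_launches;         // dweight (+)= x^T * dout, always computed
  if (!has_bias) return;      // early out, as in the updated kernel
  ++w->bias_reduce_launches;  // dbias (+)= column-sum(dout)
}

int main() {
  Workload with_bias, without_bias;
  FusedLinearParamGradAddSketch(true, &with_bias);
  FusedLinearParamGradAddSketch(false, &without_bias);
  std::cout << "bias reductions with bias: " << with_bias.bias_reduce_launches
            << ", without bias: " << without_bias.bias_reduce_launches << "\n";
  return 0;  // prints 1 and 0
}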
