
Commit 76c1bae

Authored by zhangkaihuo
[cherry pick] add op: fused_feedforward(backward) (#36730)
* add op: fused_feedforward(backward) (#35611)

  This PR adds the backward pass for fused_feedforward. Related kernel implementations: fused_dropout_act_bias, fused_residual_dropout_bias, fused_layernorm_residual_dropout_bias.

  fused_feedforward is a fused operator that fuses and wraps the operators making up the feed-forward layer of the transformer model, so the frontend exposes only a single interface. The fusion cuts part of the memory-access and kernel-launch overhead and thereby improves performance.

* Move fused_attention and fused_feedforward functional api path to incubate (#36704)

  Moves the Python API interfaces added in PRs #35905 and #35843 into the incubate directory.
1 parent 5b357e0 commit 76c1bae
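For orientation, the computation that fused_feedforward collapses into one operator is, roughly, the standard transformer feed-forward block (layer norm, linear1, activation, dropout1, linear2, residual + dropout2). The sketch below is an illustrative NumPy reference under assumed shapes, a relu activation, a simplified dropout, and a single shared layer-norm scale/bias (the real op has separate Ln1*/Ln2* parameters); it is not the Paddle API.

# Illustrative NumPy reference of the (unfused) feed-forward block that
# fused_feedforward replaces. Shapes, relu, the simplified dropout, and the
# shared gamma/beta are assumptions for illustration only.
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

def dropout(x, rate, rng):
    mask = (rng.random(x.shape) >= rate).astype(x.dtype)
    return x * mask / (1.0 - rate)

def ffn_reference(x, w1, b1, w2, b2, gamma, beta,
                  dropout_rate=0.1, pre_layer_norm=True, rng=None):
    rng = rng or np.random.default_rng(0)
    residual = x
    h = layer_norm(x, gamma, beta) if pre_layer_norm else x
    h = np.maximum(h @ w1 + b1, 0.0)                 # Linear1 + activation
    h = dropout(h, dropout_rate, rng)                # Dropout1
    h = h @ w2 + b2                                  # Linear2
    out = residual + dropout(h, dropout_rate, rng)   # residual + Dropout2
    return out if pre_layer_norm else layer_norm(out, gamma, beta)

In the fused op, the dropout masks and intermediate tensors (Linear1Out, Ln1Out, Dropout1Out, Dropout2Out) are emitted by the forward kernel so that the backward kernel added in this commit can reuse them, as the grad-maker wiring in fused_feedforward_op.cc below shows.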

File tree

9 files changed: +417 −34 lines changed


paddle/fluid/operators/fused/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
@@ -80,10 +80,8 @@ if (WITH_GPU OR WITH_ROCM)
 nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
 nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
 
-
 op_library(fused_feedforward_op)
 file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_feedforward);\n")
-
 # fused_attention_op
 op_library(fused_attention_op)
 file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n")

paddle/fluid/operators/fused/fused_feedforward_op.cc

Lines changed: 146 additions & 1 deletion
@@ -206,9 +206,154 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout1_is_test"), false,
+                      platform::errors::InvalidArgument(
+                          "GradOp is only callable when is_test is false"));
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout2_is_test"), false,
+                      platform::errors::InvalidArgument(
+                          "GradOp is only callable when is_test is false"));
+    OP_INOUT_CHECK(ctx->HasInput("Dropout1Mask"), "Input", "Dropout1Mask",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Dropout2Mask"), "Input", "Dropout2Mask",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Linear1Out"), "Input", "Linear1Out",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Dropout1Out"), "Input", "Dropout1Out",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"), "Input", "Dropout2Out",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Linear1Weight"), "Input", "Linear1Weight",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Linear2Weight"), "Input", "Linear2Weight",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln1Mean"), "Input", "Ln1Mean",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln1Variance"), "Input", "Ln1Variance",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
+                   "FusedFeedForwardGrad");
+
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "FusedFeedForwardGrad");
+
+    auto d_out_dim = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), d_out_dim);
+    if (ctx->HasOutput(framework::GradVarName("Ln1Scale"))) {
+      ctx->SetOutputDim(framework::GradVarName("Ln1Scale"),
+                        ctx->GetInputDim("Ln1Scale"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("Ln1Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Ln1Bias"),
+                        ctx->GetInputDim("Ln1Bias"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
+      ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
+                        ctx->GetInputDim("Ln2Scale"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
+                        ctx->GetInputDim("Ln2Bias"));
+    }
+    ctx->SetOutputDim(framework::GradVarName("Linear1Weight"),
+                      ctx->GetInputDim("Linear1Weight"));
+    if (ctx->HasOutput(framework::GradVarName("Linear1Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Linear1Bias"),
+                        ctx->GetInputDim("Linear1Bias"));
+    }
+    ctx->SetOutputDim(framework::GradVarName("Linear2Weight"),
+                      ctx->GetInputDim("Linear2Weight"));
+    if (ctx->HasOutput(framework::GradVarName("Linear2Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Linear2Bias"),
+                        ctx->GetInputDim("Linear2Bias"));
+    }
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input = ctx.Input<Tensor>("X");
+    auto input_data_type = input->type();
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("fused_feedforward_grad");
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Linear1Weight", this->Input("Linear1Weight"));
+    op->SetInput("Linear1Bias", this->Input("Linear1Bias"));
+    op->SetInput("Linear2Weight", this->Input("Linear2Weight"));
+    op->SetInput("Ln1Scale", this->Input("Ln1Scale"));
+    op->SetInput("Ln1Bias", this->Input("Ln1Bias"));
+    op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
+    op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
+    op->SetInput("Dropout1Mask", this->Output("Dropout1Mask"));
+    op->SetInput("Dropout2Mask", this->Output("Dropout2Mask"));
+    op->SetInput("Linear1Out", this->Output("Linear1Out"));
+    op->SetInput("Ln1Out", this->Output("Ln1Out"));
+    op->SetInput("Ln1Mean", this->Output("Ln1Mean"));
+    op->SetInput("Ln1Variance", this->Output("Ln1Variance"));
+    op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
+    op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
+    op->SetInput("Dropout1Out", this->Output("Dropout1Out"));
+    op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
+
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Ln1Scale"),
+                  this->InputGrad("Ln1Scale"));
+    op->SetOutput(framework::GradVarName("Ln1Bias"),
+                  this->InputGrad("Ln1Bias"));
+    op->SetOutput(framework::GradVarName("Ln2Scale"),
+                  this->InputGrad("Ln2Scale"));
+    op->SetOutput(framework::GradVarName("Ln2Bias"),
+                  this->InputGrad("Ln2Bias"));
+    op->SetOutput(framework::GradVarName("Linear1Weight"),
+                  this->InputGrad("Linear1Weight"));
+    op->SetOutput(framework::GradVarName("Linear1Bias"),
+                  this->InputGrad("Linear1Bias"));
+    op->SetOutput(framework::GradVarName("Linear2Weight"),
+                  this->InputGrad("Linear2Weight"));
+    if (this->HasInput("Linear2Bias")) {
+      op->SetInput("Linear2Bias", this->Input("Linear2Bias"));
+      op->SetOutput(framework::GradVarName("Linear2Bias"),
+                    this->InputGrad("Linear2Bias"));
+    }
+
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+template <typename T>
+class FusedFeedForwardOpDoubleGradMaker
+    : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {}
+};
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
-                  ops::FusedFeedForwardOpMaker);
+                  ops::FusedFeedForwardOpMaker,
+                  ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
+                  ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);

paddle/fluid/operators/fused/fused_feedforward_op.cu

Lines changed: 211 additions & 0 deletions
@@ -171,6 +171,210 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
+ public:
+  void MatMulGrad(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& d_out, const framework::Tensor& a,
+                  const framework::Tensor& b, framework::Tensor* d_a,
+                  framework::Tensor* d_b) const {
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    auto a_2d = FoldInitDims(a);
+    auto b_2d = FoldInitDims(b);
+    auto mat_dim_a = math::CreateMatrixDescriptor(a_2d.dims(), 0, true);
+    auto mat_dim_b = math::CreateMatrixDescriptor(b_2d.dims(), 0, true);
+    auto mat_dim_dout = math::CreateMatrixDescriptor(d_out.dims(), 0, false);
+    T alpha = static_cast<T>(1.0);
+    blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0));
+    blas.MatMul(a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T(0));
+  }
+
+  void FFNGrad(
+      const framework::Tensor& d_out, const framework::Tensor& x,
+      const framework::Tensor& dropout1_mask,
+      const framework::Tensor& dropout2_mask,
+      const framework::Tensor& linear1_out, const framework::Tensor& ln1_out,
+      const framework::Tensor& dropout1_out,
+      const framework::Tensor& dropout2_out,
+      const framework::Tensor& linear1_weight,
+      const framework::Tensor* linear1_bias,
+      const framework::Tensor& linear2_weight,
+      const framework::Tensor* ln1_gamma, const framework::Tensor* ln1_beta,
+      const framework::Tensor& ln1_mean, const framework::Tensor& ln1_variance,
+      const framework::Tensor* ln2_gamma, const framework::Tensor* ln2_beta,
+      const framework::Tensor& ln2_mean, const framework::Tensor& ln2_variance,
+      framework::Tensor* d_x, framework::Tensor* d_linear1_weight,
+      framework::Tensor* d_linear1_bias, framework::Tensor* d_linear2_weight,
+      framework::Tensor* d_linear2_bias, framework::Tensor* d_ln1_gamma,
+      framework::Tensor* d_ln1_beta, framework::Tensor* d_ln2_gamma,
+      framework::Tensor* d_ln2_beta, const int bsz_seq, const int d_model,
+      const int dim_feedforward, const DropoutParam& dropout_param1,
+      const DropoutParam& dropout_param2, const std::string& act_method,
+      const bool pre_layer_norm, const float epsilon1, const float epsilon2,
+      const platform::CUDADeviceContext& ctx) const {
+    FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
+        bsz_seq, d_model, epsilon1);
+    FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
+        ctx, bsz_seq, dim_feedforward, dropout_param1);
+    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
+        ctx, bsz_seq, d_model, dropout_param2, epsilon2);
+
+    auto place = ctx.GetPlace();
+    using U = LayerNormParamType<T>;
+    const U* ln1_gamma_ptr =
+        ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
+    const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
+    const U* ln2_gamma_ptr =
+        ln2_gamma == nullptr ? nullptr : ln2_gamma->data<U>();
+    const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data<U>();
+    const T* linear1_bias_ptr =
+        linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
+    T* d_linear1_bias_ptr =
+        d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data<T>();
+    T* d_linear2_bias_ptr =
+        d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data<T>();
+    U* d_ln1_gamma_ptr =
+        d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data<U>();
+    U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data<U>();
+    U* d_ln2_gamma_ptr =
+        d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data<U>();
+    U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();
+
+    framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
+    d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
+    d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
+    d_residual.mutable_data<T>({bsz_seq, d_model}, place);
+
+    if (pre_layer_norm) {
+      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
+          ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(),
+          d_linear2_out.data<T>(), d_residual.data<T>(), d_linear2_bias_ptr);
+    } else {
+      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
+          ctx, d_out.data<T>(), dropout2_out.data<T>(),
+          dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean.data<U>(),
+          ln2_variance.data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
+          d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
+          d_residual.data<T>());
+    }
+
+    framework::Tensor d_dropout1_out;
+    d_dropout1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place);
+    MatMulGrad(ctx, d_linear2_out, dropout1_out, linear2_weight,
+               &d_dropout1_out, d_linear2_weight);
+
+    framework::Tensor d_linear1_out;
+    d_linear1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place);
+    fused_act_dropout_helper.DropoutActBiasGrad(
+        ctx, d_dropout1_out.data<T>(), linear1_out.data<T>(), linear1_bias_ptr,
+        dropout1_mask.data<uint8_t>(), d_linear1_out.data<T>(),
+        d_linear1_bias_ptr, act_method);
+
+    if (pre_layer_norm) {
+      framework::Tensor d_ln1_out;
+      d_ln1_out.mutable_data<T>({bsz_seq, d_model}, place);
+      MatMulGrad(ctx, d_linear1_out, ln1_out, linear1_weight, &d_ln1_out,
+                 d_linear1_weight);
+
+      pre_layernorm_helper.LayerNormGrad(ctx, d_ln1_out.data<T>(), x.data<T>(),
+                                         ln1_gamma_ptr, ln1_mean.data<U>(),
+                                         ln1_variance.data<U>(), d_x->data<T>(),
+                                         d_ln1_gamma_ptr, d_ln1_beta_ptr);
+    } else {
+      MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
+    }
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    using U = LayerNormParamType<T>;
+    auto d_out =
+        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto x = *context.Input<framework::Tensor>("X");
+    auto dropout1_mask = *context.Input<framework::Tensor>("Dropout1Mask");
+    auto dropout2_mask = *context.Input<framework::Tensor>("Dropout2Mask");
+    auto linear1_out = *context.Input<framework::Tensor>("Linear1Out");
+    auto ln1_out = *context.Input<framework::Tensor>("Ln1Out");
+    auto dropout1_out = *context.Input<framework::Tensor>("Dropout1Out");
+    auto dropout2_out = *context.Input<framework::Tensor>("Dropout2Out");
+    auto linear1_weight = *context.Input<framework::Tensor>("Linear1Weight");
+    auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
+    auto linear2_weight = *context.Input<framework::Tensor>("Linear2Weight");
+    auto ln1_mean = *context.Input<framework::Tensor>("Ln1Mean");
+    auto ln1_variance = *context.Input<framework::Tensor>("Ln1Variance");
+    auto* ln1_scale = context.Input<framework::Tensor>("Ln1Scale");
+    auto* ln1_bias = context.Input<framework::Tensor>("Ln1Bias");
+    auto ln2_mean = *context.Input<framework::Tensor>("Ln2Mean");
+    auto ln2_variance = *context.Input<framework::Tensor>("Ln2Variance");
+    auto* ln2_scale = context.Input<framework::Tensor>("Ln2Scale");
+    auto* ln2_bias = context.Input<framework::Tensor>("Ln2Bias");
+
+    auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* d_ln1_scale =
+        context.Output<framework::Tensor>(framework::GradVarName("Ln1Scale"));
+    auto* d_ln1_bias =
+        context.Output<framework::Tensor>(framework::GradVarName("Ln1Bias"));
+    auto* d_ln2_scale =
+        context.Output<framework::Tensor>(framework::GradVarName("Ln2Scale"));
+    auto* d_ln2_bias =
+        context.Output<framework::Tensor>(framework::GradVarName("Ln2Bias"));
+    auto* d_linear1_weight = context.Output<framework::Tensor>(
+        framework::GradVarName("Linear1Weight"));
+    auto* d_linear1_bias = context.Output<framework::Tensor>(
+        framework::GradVarName("Linear1Bias"));
+    auto* d_linear2_weight = context.Output<framework::Tensor>(
+        framework::GradVarName("Linear2Weight"));
+    auto* d_linear2_bias = context.Output<framework::Tensor>(
+        framework::GradVarName("Linear2Bias"));
+
+    const float epsilon1 = context.Attr<float>("ln1_epsilon");
+    const float epsilon2 = context.Attr<float>("ln2_epsilon");
+    const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
+    const std::string act_method = context.Attr<std::string>("act_method");
+    DropoutParam dropout_param1(context, 1);
+    DropoutParam dropout_param2(context, 2);
+
+    auto place = context.GetPlace();
+    d_x->mutable_data<T>(place);
+    if (d_ln1_scale) {
+      d_ln1_scale->mutable_data<U>(place);
+    }
+    if (d_ln1_bias) {
+      d_ln1_bias->mutable_data<U>(place);
+    }
+    if (d_ln2_scale) {
+      d_ln2_scale->mutable_data<U>(place);
+    }
+    if (d_ln2_bias) {
+      d_ln2_bias->mutable_data<U>(place);
+    }
+    if (d_linear1_bias) {
+      d_linear1_bias->mutable_data<T>(place);
+    }
+    if (d_linear2_bias) {
+      d_linear2_bias->mutable_data<T>(place);
+    }
+    d_linear1_weight->mutable_data<T>(place);
+    d_linear2_weight->mutable_data<T>(place);
+
+    auto x_dim = x.dims();
+    auto mat_dim_x =
+        math::CreateMatrixDescriptor(RowMatrixFromVector(x_dim), 0, false);
+
+    auto linear1_weight_dim = linear1_weight.dims();
+    int d_model = linear1_weight_dim[0];
+    int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
+    int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
+
+    FFNGrad(d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out,
+            dropout1_out, dropout2_out, linear1_weight, linear1_bias,
+            linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance,
+            ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight,
+            d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale,
+            d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model,
+            dim_feedforward, dropout_param1, dropout_param2, act_method,
+            pre_layer_norm, epsilon1, epsilon2, context.cuda_device_context());
+  }
+};
 }  // namespace operators
 }  // namespace paddle
 
@@ -181,3 +385,10 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext, double>,
     ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext,
                                 paddle::platform::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    fused_feedforward_grad,
+    ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext,
+                                    double>,
+    ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext,
+                                    paddle::platform::float16>);
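The grad kernel above walks the forward graph in reverse: the residual/Dropout2 stage (plus LayerNorm2 in the post-layer-norm case), the Linear2 matmul, the Dropout1/activation/bias stage, and finally the Linear1 matmul (plus LayerNorm1 in the pre-layer-norm case). MatMulGrad relies on the usual matmul gradient identities, d_a = d_out·bᵀ and d_b = aᵀ·d_out for out = a·b. Below is a small NumPy check of those identities; the shapes are arbitrary assumptions chosen only to mirror the dropout1_out × linear2_weight call.

# NumPy check of the matmul gradient identities used by MatMulGrad:
# for out = a @ b, d_a = d_out @ b.T and d_b = a.T @ d_out.
# Shapes below are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((6, 4))      # e.g. dropout1_out: [bsz_seq, dim_feedforward]
b = rng.standard_normal((4, 3))      # e.g. linear2_weight: [dim_feedforward, d_model]
d_out = rng.standard_normal((6, 3))  # upstream gradient w.r.t. a @ b

d_a = d_out @ b.T
d_b = a.T @ d_out

# Compare d_a against a numerical gradient of f(a) = sum(d_out * (a @ b)).
eps = 1e-6
num_d_a = np.zeros_like(a)
for idx in np.ndindex(a.shape):
    ap, am = a.copy(), a.copy()
    ap[idx] += eps
    am[idx] -= eps
    num_d_a[idx] = (np.sum(d_out * (ap @ b)) - np.sum(d_out * (am @ b))) / (2 * eps)

print(np.allclose(d_a, num_d_a, atol=1e-5))  # expected: True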

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -91,7 +91,6 @@ foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
 endforeach()
 
 if(NOT WITH_GPU)
-
 LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
 LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
 endif()
