Skip to content

Commit b3787d1

Browse files
authored
add the matmul v2 grad kernel
* add the matmul v2 grad kernel
* reduce the test case time
* update the test case for the matmul double grad
* remove the unused code for the matmul double grad
* update the test case for the double grad matmul
* remove the unused code in dot
1 parent c727ec4 commit b3787d1

File tree

4 files changed

+808
-20
lines changed

4 files changed

+808
-20
lines changed

paddle/fluid/operators/matmul_v2_op.cc

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,59 @@ class MatMulV2GradOpMaker : public framework::SingleGradOpMaker<T> {
228228
}
229229
};
230230

231+
// Shape-inference definition for the double-grad (second-order gradient)
// op of matmul_v2.
//
// Inputs:
//   X, Y  - forward operands of matmul_v2 (required).
//   DOut  - first-order gradient of the forward output (required).
//   DDX   - optional second-order gradient flowing into DX.
//   DDY   - optional second-order gradient flowing into DY.
// Outputs (all optional; each is shaped only when its driving input exists):
//   DX    - shares X's dims, produced only when DDY is present.
//   DY    - shares Y's dims, produced only when DDX is present.
//   DDOut - shares DOut's dims, produced when either DDX or DDY is present.
class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* context) const override {
    // The three mandatory inputs must always be fed by the grad maker.
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul");
    OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
    OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut", "matmul");

    // DX is only computed from DDY, so its shape is defined only then.
    if (context->HasOutput("DX") && context->HasInput("DDY")) {
      context->ShareDim("X", "DX");
    }

    // Symmetrically, DY is only computed from DDX.
    if (context->HasOutput("DY") && context->HasInput("DDX")) {
      context->ShareDim("Y", "DY");
    }

    // DDOut needs at least one of the second-order inputs to exist and
    // always has the same shape as the first-order output gradient.
    if (context->HasOutput("DDOut") &&
        (context->HasInput("DDY") || context->HasInput("DDX"))) {
      context->ShareDim("DOut", "DDOut");
    }
  }
};
255+
256+
template <typename T>
257+
class MatMulV2OpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
258+
public:
259+
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
260+
261+
protected:
262+
void Apply(GradOpPtr<T> op) const override {
263+
op->SetType("matmul_v2_grad_grad");
264+
op->SetInput("X", this->Input("X"));
265+
op->SetInput("Y", this->Input("Y"));
266+
op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
267+
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
268+
op->SetInput("DDY", this->OutputGrad(framework::GradVarName("Y")));
269+
270+
auto ddx = this->OutputGrad(framework::GradVarName("X"));
271+
auto ddy = this->OutputGrad(framework::GradVarName("Y"));
272+
273+
if (!ddx.empty() || !ddy.empty()) {
274+
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
275+
}
276+
op->SetOutput("DX",
277+
ddy.empty() ? this->EmptyInputGrad() : this->InputGrad("X"));
278+
op->SetOutput("DY",
279+
ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Y"));
280+
281+
op->SetAttrMap(this->Attrs());
282+
}
283+
};
231284
} // namespace operators
232285
} // namespace paddle
233286

@@ -236,7 +289,11 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
236289
ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>,
237290
ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
238291

239-
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad);
292+
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad,
293+
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
294+
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>);
295+
296+
REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad);
240297

241298
REGISTER_OP_CPU_KERNEL(
242299
matmul_v2, ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, float>,
@@ -254,3 +311,11 @@ REGISTER_OP_CPU_KERNEL(
254311
paddle::platform::complex<float>>,
255312
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
256313
paddle::platform::complex<double>>);
314+
REGISTER_OP_CPU_KERNEL(
315+
matmul_v2_grad_grad,
316+
ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
317+
ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext, double>,
318+
ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext,
319+
paddle::platform::complex<float>>,
320+
ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext,
321+
paddle::platform::complex<double>>);

paddle/fluid/operators/matmul_v2_op.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,13 @@ REGISTER_OP_CUDA_KERNEL(
3030
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>,
3131
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex<float>>,
3232
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex<double>>);
33+
34+
REGISTER_OP_CUDA_KERNEL(
35+
matmul_v2_grad_grad,
36+
ops::MatMulV2DoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
37+
ops::MatMulV2DoubleGradKernel<paddle::platform::CUDADeviceContext, double>,
38+
ops::MatMulV2DoubleGradKernel<plf::CUDADeviceContext, plf::float16>,
39+
ops::MatMulV2DoubleGradKernel<paddle::platform::CUDADeviceContext,
40+
paddle::platform::complex<float>>,
41+
ops::MatMulV2DoubleGradKernel<paddle::platform::CUDADeviceContext,
42+
paddle::platform::complex<double>>);

0 commit comments

Comments
 (0)