@@ -228,6 +228,59 @@ class MatMulV2GradOpMaker : public framework::SingleGradOpMaker<T> {
228228 }
229229};
230230
231+ class MatMulV2OpDoubleGrad : public framework ::OperatorWithKernel {
232+ public:
233+ using framework::OperatorWithKernel::OperatorWithKernel;
234+
235+ protected:
236+ void InferShape (framework::InferShapeContext* context) const override {
237+ OP_INOUT_CHECK (context->HasInput (" X" ), " Input" , " X" , " matmul" );
238+ OP_INOUT_CHECK (context->HasInput (" Y" ), " Input" , " Y" , " matmul" );
239+ OP_INOUT_CHECK (context->HasInput (" DOut" ), " Input" , " DOut" , " matmul" );
240+
241+ if (context->HasOutput (" DX" ) && context->HasInput (" DDY" )) {
242+ context->ShareDim (" X" , " DX" );
243+ }
244+
245+ if (context->HasOutput (" DY" ) && context->HasInput (" DDX" )) {
246+ context->ShareDim (" Y" , " DY" );
247+ }
248+
249+ if (context->HasOutput (" DDOut" ) &&
250+ (context->HasInput (" DDY" ) || context->HasInput (" DDX" ))) {
251+ context->ShareDim (" DOut" , " DDOut" );
252+ }
253+ }
254+ };
255+
256+ template <typename T>
257+ class MatMulV2OpDoubleGradMaker : public framework ::SingleGradOpMaker<T> {
258+ public:
259+ using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
260+
261+ protected:
262+ void Apply (GradOpPtr<T> op) const override {
263+ op->SetType (" matmul_v2_grad_grad" );
264+ op->SetInput (" X" , this ->Input (" X" ));
265+ op->SetInput (" Y" , this ->Input (" Y" ));
266+ op->SetInput (" DOut" , this ->Input (framework::GradVarName (" Out" )));
267+ op->SetInput (" DDX" , this ->OutputGrad (framework::GradVarName (" X" )));
268+ op->SetInput (" DDY" , this ->OutputGrad (framework::GradVarName (" Y" )));
269+
270+ auto ddx = this ->OutputGrad (framework::GradVarName (" X" ));
271+ auto ddy = this ->OutputGrad (framework::GradVarName (" Y" ));
272+
273+ if (!ddx.empty () || !ddy.empty ()) {
274+ op->SetOutput (" DDOut" , this ->InputGrad (framework::GradVarName (" Out" )));
275+ }
276+ op->SetOutput (" DX" ,
277+ ddy.empty () ? this ->EmptyInputGrad () : this ->InputGrad (" X" ));
278+ op->SetOutput (" DY" ,
279+ ddx.empty () ? this ->EmptyInputGrad () : this ->InputGrad (" Y" ));
280+
281+ op->SetAttrMap (this ->Attrs ());
282+ }
283+ };
231284} // namespace operators
232285} // namespace paddle
233286
@@ -236,7 +289,11 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
236289 ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>,
237290 ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
238291
239- REGISTER_OPERATOR (matmul_v2_grad, ops::MatMulV2OpGrad);
292+ REGISTER_OPERATOR (matmul_v2_grad, ops::MatMulV2OpGrad,
293+ ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
294+ ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>);
295+
296+ REGISTER_OPERATOR (matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad);
240297
241298REGISTER_OP_CPU_KERNEL (
242299 matmul_v2, ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, float >,
@@ -254,3 +311,11 @@ REGISTER_OP_CPU_KERNEL(
254311 paddle::platform::complex <float >>,
255312 ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
256313 paddle::platform::complex <double >>);
314+ REGISTER_OP_CPU_KERNEL (
315+ matmul_v2_grad_grad,
316+ ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext, float >,
317+ ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext, double >,
318+ ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext,
319+ paddle::platform::complex <float >>,
320+ ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext,
321+ paddle::platform::complex <double >>);
0 commit comments