@@ -73,63 +73,17 @@ class EigvalsOpVarTypeInference : public framework::VarTypeInference {
7373 ctx->SetOutputDataType (" Out" , output_dtype);
7474 }
7575};
76-
77- class EigvalsGradOp : public framework ::OperatorWithKernel {
78- public:
79- using framework::OperatorWithKernel::OperatorWithKernel;
80- void InferShape (framework::InferShapeContext* ctx) const override {
81- OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input" , " X" , " EigvalsGrad" );
82- OP_INOUT_CHECK (ctx->HasInput (framework::GradVarName (" Out" )), " Input" ,
83- " Out@Grad" , " EigvalsGrad" );
84- OP_INOUT_CHECK (ctx->HasOutput (framework::GradVarName (" X" )), " Output" ,
85- " X@Grad" , " EigvalsGrad" );
86- ctx->SetOutputDim (framework::GradVarName (" X" ), ctx->GetInputDim (" X" ));
87- }
88-
89- protected:
90- framework::OpKernelType GetExpectedKernelType (
91- const framework::ExecutionContext& ctx) const override {
92- return framework::OpKernelType (
93- OperatorWithKernel::IndicateVarDataType (ctx, " X" ), ctx.GetPlace ());
94- }
95- };
96-
97- template <typename T>
98- class EigvalsGradOpMaker : public framework ::SingleGradOpMaker<T> {
99- public:
100- using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
101-
102- protected:
103- void Apply (GradOpPtr<T> retv) const override {
104- retv->SetType (" eigvals_grad" );
105- retv->SetInput (" X" , this ->Input (" X" ));
106- retv->SetInput (framework::GradVarName (" Out" ), this ->OutputGrad (" Out" ));
107- retv->SetOutput (framework::GradVarName (" X" ), this ->InputGrad (" X" ));
108- }
109- };
11076} // namespace operators
11177} // namespace paddle
11278namespace ops = paddle::operators;
11379namespace plat = paddle::platform;
11480
11581REGISTER_OPERATOR (eigvals, ops::EigvalsOp, ops::EigvalsOpMaker,
116- ops::EigvalsOpVarTypeInference,
117- ops::EigvalsGradOpMaker<paddle::framework::OpDesc>,
118- ops::EigvalsGradOpMaker<paddle::imperative::OpBase>);
119- REGISTER_OPERATOR (eigvals_grad, ops::EigvalsGradOp);
82+ ops::EigvalsOpVarTypeInference);
12083REGISTER_OP_CPU_KERNEL (eigvals,
12184 ops::EigvalsKernel<plat::CPUDeviceContext, float >,
12285 ops::EigvalsKernel<plat::CPUDeviceContext, double >,
12386 ops::EigvalsKernel<plat::CPUDeviceContext,
12487 paddle::platform::complex <float >>,
12588 ops::EigvalsKernel<plat::CPUDeviceContext,
12689 paddle::platform::complex <double >>);
127-
128- // TODO(Ruibiao): Support gradient kernel for Eigvals OP
129- // REGISTER_OP_CPU_KERNEL(eigvals_grad,
130- // ops::EigvalsGradKernel<plat::CPUDeviceContext, float>,
131- // ops::EigvalsGradKernel<plat::CPUDeviceContext, double>,
132- // ops::EigvalsGradKernel<plat::CPUDeviceContext,
133- // paddle::platform::complex<float>>,
134- // ops::EigvalsGradKernel<plat::CPUDeviceContext,
135- // paddle::platform::complex<double>>);
0 commit comments