Skip to content

Commit 971bab0

Browse files
committed
Delete grad Op
1 parent 7b1aa3c commit 971bab0

File tree

2 files changed

+1
-87
lines changed

2 files changed

+1
-87
lines changed

paddle/fluid/operators/eigvals_op.cc

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -73,63 +73,17 @@ class EigvalsOpVarTypeInference : public framework::VarTypeInference {
7373
ctx->SetOutputDataType("Out", output_dtype);
7474
}
7575
};
76-
77-
class EigvalsGradOp : public framework::OperatorWithKernel {
78-
public:
79-
using framework::OperatorWithKernel::OperatorWithKernel;
80-
void InferShape(framework::InferShapeContext* ctx) const override {
81-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "EigvalsGrad");
82-
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
83-
"Out@Grad", "EigvalsGrad");
84-
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
85-
"X@Grad", "EigvalsGrad");
86-
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
87-
}
88-
89-
protected:
90-
framework::OpKernelType GetExpectedKernelType(
91-
const framework::ExecutionContext& ctx) const override {
92-
return framework::OpKernelType(
93-
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
94-
}
95-
};
96-
97-
template <typename T>
98-
class EigvalsGradOpMaker : public framework::SingleGradOpMaker<T> {
99-
public:
100-
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
101-
102-
protected:
103-
void Apply(GradOpPtr<T> retv) const override {
104-
retv->SetType("eigvals_grad");
105-
retv->SetInput("X", this->Input("X"));
106-
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
107-
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
108-
}
109-
};
11076
} // namespace operators
11177
} // namespace paddle
11278
namespace ops = paddle::operators;
11379
namespace plat = paddle::platform;
11480

11581
REGISTER_OPERATOR(eigvals, ops::EigvalsOp, ops::EigvalsOpMaker,
116-
ops::EigvalsOpVarTypeInference,
117-
ops::EigvalsGradOpMaker<paddle::framework::OpDesc>,
118-
ops::EigvalsGradOpMaker<paddle::imperative::OpBase>);
119-
REGISTER_OPERATOR(eigvals_grad, ops::EigvalsGradOp);
82+
ops::EigvalsOpVarTypeInference);
12083
REGISTER_OP_CPU_KERNEL(eigvals,
12184
ops::EigvalsKernel<plat::CPUDeviceContext, float>,
12285
ops::EigvalsKernel<plat::CPUDeviceContext, double>,
12386
ops::EigvalsKernel<plat::CPUDeviceContext,
12487
paddle::platform::complex<float>>,
12588
ops::EigvalsKernel<plat::CPUDeviceContext,
12689
paddle::platform::complex<double>>);
127-
128-
// TODO(Ruibiao): Support gradient kernel for Eigvals OP
129-
// REGISTER_OP_CPU_KERNEL(eigvals_grad,
130-
// ops::EigvalsGradKernel<plat::CPUDeviceContext, float>,
131-
// ops::EigvalsGradKernel<plat::CPUDeviceContext, double>,
132-
// ops::EigvalsGradKernel<plat::CPUDeviceContext,
133-
// paddle::platform::complex<float>>,
134-
// ops::EigvalsGradKernel<plat::CPUDeviceContext,
135-
// paddle::platform::complex<double>>);

python/paddle/fluid/tests/unittests/test_eigvals_op.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,6 @@ def np_eigvals(a):
3232
return res
3333

3434

35-
''' Keep this for testing the gradient kernel in the future.
36-
def np_eigvals_grad(a, out_grad):
37-
l, v = np.linalg.eig(a)
38-
print("l:")
39-
print(l)
40-
print("v:")
41-
print(v)
42-
vh = v.conj().T
43-
print("vh:")
44-
print(vh)
45-
print("out_grad:")
46-
print(out_grad)
47-
a_grad = np.linalg.solve(vh, np.diagflat(out_grad, 0) * vh)
48-
print("a_grad")
49-
print(a_grad)
50-
51-
return a_grad.astype(a.dtype)
52-
'''
53-
54-
5535
class TestEigvalsOp(OpTest):
5636
def setUp(self):
5737
np.random.seed(0)
@@ -86,26 +66,6 @@ def test_check_output(self):
8666
self.check_output_with_place_customized(
8767
checker=self.verify_output, place=core.CPUPlace())
8868

89-
''' The gradient kernel of this operator has not yet been developed.
90-
def test_check_grad_normal(self):
91-
self.grad_dtype = self.dtype
92-
if self.dtype == np.float32:
93-
self.grad_dtype = np.complex64
94-
elif self.dtype == np.float64:
95-
self.grad_dtype = np.complex128
96-
97-
self.out_grad = (np.random.random(self.input_dims[-1:]) +
98-
np.random.random(self.input_dims[-1:]) * 1j).astype(self.grad_dtype)
99-
self.x_grad = np_eigvals_grad(self.input_data, self.out_grad)
100-
101-
print("np_eigvals_grad:\n")
102-
print(self.x_grad)
103-
104-
self.check_grad(['X'], 'Out',
105-
user_defined_grads=[self.x_grad],
106-
user_defined_grad_outputs=[self.out_grad])
107-
'''
108-
10969
def verify_output(self, outs):
11070
actual_outs = np.sort(np.array(outs[0]))
11171
expect_outs = np.sort(np.array(self.outputs['Out']))

0 commit comments

Comments
 (0)