Commit 2938e07

thisjiang authored and zhhsplendid committed
add gradient kernel of det op and slogdet op (PaddlePaddle#36013)
* add gradient kernel of det op and slogdet op
* fix CI APPROVAL problem
1 parent 928f937 commit 2938e07

File tree

4 files changed: +266 −75 lines changed
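For context (this derivation is not part of the commit itself), the backward passes of these ops presumably follow the standard matrix-calculus identities

\[ \frac{\partial \det(A)}{\partial A} = \det(A)\,A^{-\top}, \qquad \frac{\partial \log\lvert\det(A)\rvert}{\partial A} = A^{-\top}. \]

With an upstream gradient \(\bar{y}\) on the scalar output, the input gradient is \(\bar{A} = \bar{y}\,\det(A)\,A^{-\top}\) for det and \(\bar{A} = \bar{y}\,A^{-\top}\) for slogdet; the sign output of slogdet contributes no gradient almost everywhere.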

paddle/fluid/operators/determinant_op.cc

Lines changed: 7 additions & 4 deletions
@@ -48,6 +48,8 @@ class DeterminantGradOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
                    "DeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "DeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
                    framework::GradVarName("Input"), "DeterminantGradOp");
 
@@ -117,7 +119,8 @@ class SlogDeterminantGradOp : public framework::OperatorWithKernel {
                    "SlogDeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out",
                    "SlogDeterminantGradOp");
-
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "SlogDeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
                    framework::GradVarName("Input"), "SlogDeterminantGradOp");
 
@@ -179,13 +182,13 @@ REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
                   ops::SlogDeterminantGradOpMaker<paddle::imperative::OpBase>);
 
 REGISTER_OPERATOR(slogdeterminant_grad,
-                  ops::DeterminantGradOp)  // reuse det grad op
+                  ops::SlogDeterminantGradOp)  // reuse det grad op
 
 REGISTER_OP_CPU_KERNEL(
     slogdeterminant, ops::SlogDeterminantKernel<plat::CPUDeviceContext, float>,
     ops::SlogDeterminantKernel<plat::CPUDeviceContext, double>);
 
 REGISTER_OP_CPU_KERNEL(
     slogdeterminant_grad,
-    ops::DeterminantGradKernel<plat::CPUDeviceContext, float>,
-    ops::DeterminantGradKernel<plat::CPUDeviceContext, double>);
+    ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, float>,
+    ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, double>);
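The registration fix above matters because slogdeterminant's output is (sign, log|det|) rather than det itself, so its backward formula differs from the determinant's and can no longer reuse DeterminantGradOp / DeterminantGradKernel. As a minimal standalone sketch of the two formulas (Eigen-based, assuming the identities above; this is not Paddle's kernel code, which lives in files not shown in this excerpt):

// Reference gradients for det(A) and slogdet(A) on a single square matrix.
// grad_out is the upstream gradient flowing into the scalar output.
#include <Eigen/Dense>
#include <iostream>

// d(det(A))/dA scaled by the upstream gradient: grad_out * det(A) * A^{-T}
Eigen::MatrixXd DetGrad(const Eigen::MatrixXd& A, double grad_out) {
  return grad_out * A.determinant() * A.inverse().transpose();
}

// d(log|det(A)|)/dA scaled by the upstream gradient: grad_out * A^{-T}
// (the sign output of slogdet has zero gradient almost everywhere)
Eigen::MatrixXd SlogDetGrad(const Eigen::MatrixXd& A, double grad_out) {
  return grad_out * A.inverse().transpose();
}

int main() {
  Eigen::MatrixXd A(2, 2);
  A << 4.0, 1.0,
       2.0, 3.0;  // det(A) = 10
  std::cout << "det grad:\n" << DetGrad(A, 1.0) << "\n";
  std::cout << "slogdet grad:\n" << SlogDetGrad(A, 1.0) << "\n";
  return 0;
}

For a batched input of shape [..., n, n], the same per-matrix formula is applied to each n x n slice, with the upstream gradient broadcast over the batch dimensions.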

paddle/fluid/operators/determinant_op.cu

Lines changed: 0 additions & 36 deletions
@@ -14,42 +14,6 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/determinant_op.h"
-#include "paddle/fluid/platform/cuda_primitives.h"
-
-namespace paddle {
-namespace operators {
-
-using platform::PADDLE_CUDA_NUM_THREADS;
-using Tensor = framework::Tensor;
-
-template <typename T>
-__global__ void DeterminantGrad(const size_t numel, T* out) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid < numel) {
-    out[tid] = static_cast<T>(1);
-  }
-}
-
-template <typename T>
-class DeterminantGradCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
-    const T* dout_data = dout->data<T>();
-    auto dout_dim = vectorize(dout->dims());
-
-    auto* dx = context.Output<Tensor>(framework::GradVarName("Input"));
-    T* dx_data = dx->mutable_data<T>(context.GetPlace());
-
-    int64_t numel = dx->numel();
-    for (int64_t idx = 0; idx < numel; idx++) {
-      dx_data[idx] = static_cast<T>(1);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
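The deletion above removes only a placeholder: DeterminantGradCUDAKernel read Out@GRAD but never used it and simply filled the input gradient with ones, which is not the determinant's gradient. With this commit the analytic gradient kernels take over (presumably registered further down this file, outside the hunk shown), so the device-specific stub is no longer needed.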
