Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions paddle/fluid/operators/atan2_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/atan2_op.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace paddle {
namespace operators {

class Atan2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X1"), "Input", "X1", "atan2");
OP_INOUT_CHECK(ctx->HasInput("X2"), "Input", "X2", "atan2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "atan2");

auto in_dims = ctx->GetInputDim("X1");

ctx->SetOutputDim("Out", in_dims);
}
};

class Atan2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X1", "(Tensor), The input tensor of atan2 op.");
AddInput("X2", "(Tensor), The input tensor of atan2 op.");
AddOutput("Out", "(Tensor), The output tensor of atan2 op.");
AddComment(R"DOC(
Atan2 Operator.

This operator is used to perform elementwise atan2 for input $X1$, $X2$.
$$out = atan2(x1, x2)$$

)DOC");
}
};

class Atan2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X1"), "Input", "X1", "Atan2Grad");
OP_INOUT_CHECK(ctx->HasInput("X2"), "Input", "X2", "Atan2Grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@Grad", "Atan2Grad");

auto x1_grad_name = framework::GradVarName("X1");
auto x2_grad_name = framework::GradVarName("X2");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));

if (ctx->HasOutput(x1_grad_name)) {
ctx->SetOutputDim(framework::GradVarName("X1"), dout_dims);
}
if (ctx->HasOutput(x2_grad_name)) {
ctx->SetOutputDim(framework::GradVarName("X2"), dout_dims);
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X1");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};

template <typename T>
class Atan2GradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> retv) const override {
retv->SetType("atan2_grad");
retv->SetInput("X1", this->Input("X1"));
retv->SetInput("X2", this->Input("X2"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("X1"), this->InputGrad("X1"));
retv->SetOutput(framework::GradVarName("X2"), this->InputGrad("X2"));
}
};

class Atan2OpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto type = ctx->GetInputDataType("X1");
if (ctx->GetInputDataType("X1") == framework::proto::VarType::INT32 ||
ctx->GetInputDataType("X1") == framework::proto::VarType::INT64 ||
ctx->GetInputDataType("X2") == framework::proto::VarType::INT32 ||
ctx->GetInputDataType("X2") == framework::proto::VarType::INT64) {
type = framework::proto::VarType::FP64;
}
ctx->SetOutputDataType("Out", type);
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(atan2, ops::Atan2Op, ops::Atan2OpMaker,
ops::Atan2GradMaker<paddle::framework::OpDesc>,
ops::Atan2GradMaker<paddle::imperative::OpBase>,
ops::Atan2OpVarTypeInference);

REGISTER_OPERATOR(atan2_grad, ops::Atan2GradOp);

REGISTER_OP_CPU_KERNEL(
atan2, ops::Atan2Kernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::Atan2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Atan2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::Atan2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::Atan2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);

REGISTER_OP_CPU_KERNEL(
atan2_grad, ops::Atan2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Atan2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Atan2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么不支持int类型的输入?

Copy link
Contributor Author

@ronny1996 ronny1996 Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

调查了tensorflow和pytorch,tensorflow不支持int;pytorch前向支持int,且输入int时输出为fp64,反向不支持
更新成和pytorch一致

31 changes: 31 additions & 0 deletions paddle/fluid/operators/atan2_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/atan2_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
atan2, ops::Atan2Kernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::Atan2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Atan2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Atan2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::Atan2Kernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);

REGISTER_OP_CUDA_KERNEL(
atan2_grad,
ops::Atan2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Atan2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Atan2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
168 changes: 168 additions & 0 deletions paddle/fluid/operators/atan2_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using framework::To32BitIndex;

template <typename T>
struct Atan2Out {
using type = T;
};

template <>
struct Atan2Out<int32_t> {
using type = double;
};

template <>
struct Atan2Out<int64_t> {
using type = double;
};

template <typename T>
struct Atan2Functor {
Atan2Functor(const T* x1, const T* x2, typename Atan2Out<T>::type* out,
int64_t numel)
: x1_(x1), x2_(x2), out_(out), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
out_[idx] = static_cast<typename Atan2Out<T>::type>(
::atan2f(static_cast<float>(x1_[idx]), static_cast<float>(x2_[idx])));
}

const T* x1_;
const T* x2_;
typename Atan2Out<T>::type* out_;
int64_t numel_;
};

template <>
struct Atan2Functor<double> {
Atan2Functor(const double* x1, const double* x2, double* out, int64_t numel)
: x1_(x1), x2_(x2), out_(out), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
out_[idx] = ::atan2(x1_[idx], x2_[idx]);
}

const double* x1_;
const double* x2_;
double* out_;
int64_t numel_;
};

// dx1 = dout * x2 / ((x1)^2 + (x2)^2)
// dx2 = - dout * x1 / ((x1)^2 + (x2)^2)
template <typename T>
struct Atan2GradFunctor {
Atan2GradFunctor(const T* x1, const T* x2, const T* dout, T* dx1, T* dx2,
int64_t numel)
: x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
float x1 = static_cast<float>(x1_[idx]);
float x2 = static_cast<float>(x2_[idx]);
float x = x1 * x1 + x2 * x2;
dx1_[idx] = static_cast<T>(static_cast<float>(dout_[idx]) * x2 / x);
dx2_[idx] = static_cast<T>(-static_cast<float>(dout_[idx]) * x1 / x);
}

const T* x1_;
const T* x2_;
const T* dout_;
T* dx1_;
T* dx2_;
int64_t numel_;
};

template <>
struct Atan2GradFunctor<double> {
Atan2GradFunctor(const double* x1, const double* x2, const double* dout,
double* dx1, double* dx2, int64_t numel)
: x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
auto x = x1_[idx] * x1_[idx] + x2_[idx] * x2_[idx];
dx1_[idx] = dout_[idx] * x2_[idx] / x;
dx2_[idx] = -dout_[idx] * x1_[idx] / x;
}

const double* x1_;
const double* x2_;
const double* dout_;
double* dx1_;
double* dx2_;
int64_t numel_;
};

template <typename DeviceContext, typename T>
class Atan2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* X1 = context.Input<Tensor>("X1");
const Tensor* X2 = context.Input<Tensor>("X2");
Tensor* Out = context.Output<Tensor>("Out");

auto numel = X1->numel();
auto x1 = X1->data<T>();
auto x2 = X2->data<T>();
auto out = Out->mutable_data<typename Atan2Out<T>::type>(
context.GetPlace(), size_t(numel * sizeof(typename Atan2Out<T>::type)));
auto& dev_ctx = context.template device_context<DeviceContext>();

platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
Atan2Functor<T> functor(x1, x2, out, numel);
for_range(functor);
}
};

template <typename DeviceContext, typename T>
class Atan2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const {
const Tensor* X1 = context.Input<Tensor>("X1");
const Tensor* X2 = context.Input<Tensor>("X2");
const Tensor* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* dX1 = context.Output<Tensor>(framework::GradVarName("X1"));
Tensor* dX2 = context.Output<Tensor>(framework::GradVarName("X2"));

auto numel = X1->numel();
auto x1 = X1->data<T>();
auto x2 = X2->data<T>();
auto dout = dOut->data<T>();
auto dx1 =
dX1->mutable_data<T>(context.GetPlace(), size_t(numel * sizeof(T)));
auto dx2 =
dX2->mutable_data<T>(context.GetPlace(), size_t(numel * sizeof(T)));
auto& dev_ctx = context.template device_context<DeviceContext>();

platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
Atan2GradFunctor<T> functor(x1, x2, dout, dx1, dx2, numel);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
from .tensor.math import acos # noqa: F401
from .tensor.math import asin # noqa: F401
from .tensor.math import atan # noqa: F401
from .tensor.math import atan2 # noqa: F401
Copy link
Contributor

@qili93 qili93 May 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tensor/__init__.py里面也需要添加from .math import atan

from .tensor.math import ceil # noqa: F401
from .tensor.math import cos # noqa: F401
from .tensor.math import tan # noqa: F401
Expand Down Expand Up @@ -425,6 +426,7 @@
'divide',
'ceil',
'atan',
'atan2',
'expand',
'broadcast_to',
'ones_like',
Expand Down
Loading