From 53c24487b3d270201fa8cfa2c50f307a77ad1ee7 Mon Sep 17 00:00:00 2001 From: jeff41404 Date: Thu, 12 Aug 2021 06:39:51 +0000 Subject: [PATCH] add fft c2c cufft kernel --- paddle/fluid/operators/spectral_op.cu | 85 +++++++++++++++++++++++++++ paddle/fluid/operators/spectral_op.h | 52 ++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 paddle/fluid/operators/spectral_op.cu create mode 100644 paddle/fluid/operators/spectral_op.h diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu new file mode 100644 index 00000000000000..2af9ee8e6b3424 --- /dev/null +++ b/paddle/fluid/operators/spectral_op.cu @@ -0,0 +1,85 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include "paddle/fluid/operators/spectral_op.h" +#include "paddle/fluid/platform/complex.h" + +namespace paddle { +namespace operators { + +namespace { +template +void fft_c2c_cufft(const DeviceContext& ctx, const Tensor* X, Tensor* out, + const std::vector& axes, int64_t normalization, + bool forward) { + // const auto x_dims = x->dims(); +} + +template +void fft_c2c_cufft_backward(const DeviceContext& ctx, const Tensor* X, + Tensor* out, const std::vector& axes, + int64_t normalization, bool forward) {} + +} // anonymous namespace + +template +class FFTC2CKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using U = paddle::platform::complex; + auto& dev_ctx = ctx.device_context(); + + auto axes = ctx.Attr>("axes"); + const std::string norm_str = ctx.Attr("normalization"); + const bool forward = ctx.Attr("forward"); + auto* x = ctx.Input("X"); + auto* y = ctx.Output("Out"); + + auto* y_data = y->mutable_data(ctx.GetPlace()); + auto normalization = get_norm_from_string(norm_str, forward); + + fft_c2c_cufft(dev_ctx, x, y, axes, + normalization, forward); + } +}; + +template +class FFTC2CGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + using U = FFTC2CParamType; + auto& dev_ctx = ctx.device_context(); + + auto axes = ctx.Attr>("axes"); + const int64_t normalization = ctx.Attr("normalization"); + const bool forward = ctx.Attr("forward"); + auto* d_x = ctx.Output(framework::GradVarName("X")); + auto* d_y = ctx.Input(framework::GradVarName("Out")); + + auto *d_y_data = d_y->mutable_data(ctx.GetPlace() + auto normalization = get_norm_from_string(norm_str, forward); + + fft_c2c_cufft_backward(dev_ctx, d_x, d_y, + axes, normalization, forward); + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(fft_c2c, ops::FFTC2CKernel, + ops::FFTC2CKernel); + +REGISTER_OP_CUDA_KERNEL(fft_c2c_grad, ops::FFTC2CGradKernel, + ops::FFTC2CGradKernel); diff --git a/paddle/fluid/operators/spectral_op.h b/paddle/fluid/operators/spectral_op.h new file mode 100644 index 00000000000000..1df677c90d232c --- /dev/null +++ b/paddle/fluid/operators/spectral_op.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +enum class FFTNormMode : int64_t { + none, // No normalization + by_sqrt_n, // Divide by sqrt(signal_size) + by_n, // Divide by signal_size +}; + +// Convert normalization mode string to enum values +// NOTE: for different direction, normalization modes have different meanings. +// eg: "forward" translates to `by_n` for a forward transform and `none` for +// backward. +FFTNormMode get_norm_from_string(const std::string& norm, bool forward) { + if (!norm || *norm == "backward") { + return forward ? FFTNormMode::none : FFTNormMode::by_n; + } + + if (*norm == "forward") { + return forward ? FFTNormMode::by_n : FFTNormMode::none; + } + + if (*norm == "ortho") { + return FFTNormMode::by_sqrt_n; + } + + PADDLE_THROW(platform::errors::InvalidArgument( + "Fft norm string must be forward or backward or ortho")); +} + +template +class FFTC2CKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override; +}; + +} // namespace operators +} // namespace paddle