Merged
112 commits
5dff7f9
1. add interface for fft;
Aug 10, 2021
53c2448
add fft c2c cufft kernel
jeff41404 Aug 12, 2021
3c22959
implement argument checking & op calling parts for fft_c2c and fftn_c2c
Aug 13, 2021
b07ee4c
add operator and opmaker definitions
Aug 13, 2021
579eda0
only register float and double for cpu.
Aug 13, 2021
f0d0413
Merge pull request #1 from iclementine/fft
jeff41404 Aug 13, 2021
8e51eea
Merge pull request #2 from jeff41404/fft_c2c_cufft
Aug 13, 2021
6451ac8
add common code for implementing FFT, add pocketfft as a dependency
Aug 17, 2021
4cf05fb
Merge pull request #3 from iclementine/fft
jeff41404 Aug 17, 2021
0ba4495
add fft c2c cufft kernel function
jeff41404 Aug 17, 2021
b187832
fix bugs in python interface
Aug 17, 2021
c3820f2
Merge pull request #4 from jeff41404/fft_c2c_cufft
Aug 17, 2021
034cddb
add support for c2r, r2c operators, op makers, kernels and kernel fun…
Aug 18, 2021
e4bcaed
test and fix bugs
Aug 18, 2021
b00a0f1
1. fft_c2c function: add support for onesided=False;
Aug 18, 2021
322b9e3
Merge pull request #5 from iclementine/fft
jeff41404 Aug 18, 2021
e691dca
1. fft: fix python api bugs;
Aug 18, 2021
9f0cc98
fft c2c cufft kernel done with compile and link
jeff41404 Aug 19, 2021
96e8b09
Merge pull request #6 from jeff41404/fft_c2c_cufft
Aug 19, 2021
6864616
Merge branch 'project_fft' of github.com:iclementine/Paddle into fft
Aug 19, 2021
7402387
fix shape_op, add mkl placeholder
Aug 20, 2021
5866c75
Merge pull request #7 from iclementine/fft
jeff41404 Aug 20, 2021
5f3166d
remove mkl
Aug 20, 2021
bce9488
Merge pull request #8 from iclementine/fft
jeff41404 Aug 20, 2021
8349129
complete fft c2c in gpu
Aug 20, 2021
dea0e79
Merge pull request #9 from jeff41404/fft_c2c_cufft
Aug 20, 2021
d224cc4
1. implement mkl-based fft, FFTC2CFunctor and common function exec_fft;
Aug 25, 2021
7242079
Merge pull request #10 from iclementine/fft
jeff41404 Aug 25, 2021
c200985
complete fft c2c on gpu in ND
Aug 25, 2021
721e532
Merge branch 'project_fft' of https://github.com/iclementine/Paddle i…
Aug 25, 2021
40803ea
complete fft c2c on gpu in ND
Aug 25, 2021
12f26b5
Merge pull request #11 from jeff41404/fft_c2c_cufft
Aug 25, 2021
1edd8d1
complete fft c2c backward in ND
Aug 26, 2021
3d7a9ea
Merge pull request #12 from jeff41404/fft_c2c_cufft
Aug 26, 2021
aa33104
fix MKL-based implementation
Aug 30, 2021
ff80dd4
resolve conflict
Aug 30, 2021
fc52a46
Merge pull request #14 from iclementine/fft
jeff41404 Aug 30, 2021
4a0b182
Add frame op and CPU/GPU kernels.
KPatr1ck Aug 27, 2021
dcbf6ce
Add frame op forward unittest.
KPatr1ck Aug 30, 2021
19f9f77
Add frame op forward unittest.
KPatr1ck Aug 30, 2021
930f523
Remove axis parameter in FrameFunctor.
KPatr1ck Aug 30, 2021
d89b583
Add frame op grad CPU/GPU kernels and unittest.
KPatr1ck Aug 31, 2021
d0c7911
Add frame op grad CPU/GPU kernels and unittest.
KPatr1ck Aug 31, 2021
74dc1e4
Update doc string.
KPatr1ck Aug 31, 2021
66abd71
Update after review and remove librosa requirement in unittest.
KPatr1ck Sep 1, 2021
ef984a8
Update grad kernel.
KPatr1ck Sep 1, 2021
c4c21c3
Merge pull request #13 from KPatr1ck/cxj_fft
Sep 1, 2021
4168872
add fft_c2r op
lijiaqi0612 Sep 1, 2021
d78f7ad
Remove data allocation in TransCompute function.
KPatr1ck Sep 1, 2021
9639179
Merge pull request #15 from KPatr1ck/cxj_fft
Sep 1, 2021
05124e0
add fft r2c onesided with cpu(pocketfft/mkl) and gpu
Sep 2, 2021
c8b96e5
last fft c2r functor
lijiaqi0612 Sep 2, 2021
5a46860
Merge branch 'project_fft' of https://github.com/iclementine/Paddle i…
Sep 2, 2021
e4fe761
Merge pull request #16 from lijiaqi0612/fft_c2r
Sep 4, 2021
a82eded
fix C2R and R2C for cufft, because the direction is not an option in t…
Sep 4, 2021
9f16534
Merge pull request #17 from iclementine/fft
Sep 4, 2021
6cb27d6
add fft r2c onesided with cpu(pocketfft/mkl) and gpu
Sep 2, 2021
664c452
fix bugs in python APIs
Sep 4, 2021
efc26c9
fix fft_c2r grad kernel
lijiaqi0612 Sep 4, 2021
8816fb7
build fft_r2c onesided pass with cpu and gpu
Sep 4, 2021
f98c65c
Merge pull request #18 from cxxly/fft_r2c
Sep 4, 2021
b5a7762
Merge branch 'develop' into project_fft
Sep 4, 2021
6e05742
fix bugs in python APIs
Sep 4, 2021
ad3d6af
Merge branch 'project_fft' of https://github.com/iclementine/Paddle i…
lijiaqi0612 Sep 4, 2021
dfbce40
add cuda fft c2r grad kernel functor
lijiaqi0612 Sep 6, 2021
aa1e344
clean code
lijiaqi0612 Sep 6, 2021
6fec29b
Merge pull request #20 from lijiaqi0612/fft_c2r
Sep 6, 2021
52293b1
Merge pull request #21 from iclementine/fft
Sep 6, 2021
1b414db
Merge branch 'project_fft' of github.com:iclementine/Paddle into fft
Sep 6, 2021
d074050
fix fft_c2r python API
Sep 6, 2021
3c32139
Merge pull request #22 from iclementine/fft
Sep 6, 2021
a31dc31
fill fft r2c result with conjugate symmetry (#19)
cxxly Sep 7, 2021
316590d
add placeholder for unittests (#24)
Sep 8, 2021
75c2ca0
simplify parameterize test function by auto generate test case from par…
cxxly Sep 9, 2021
7a22378
miscellaneous fixes for python APIs (#26)
Sep 9, 2021
d21249e
fix typos in axes checking (#27)
Sep 9, 2021
3e7792e
fix argument checking (#28)
Sep 9, 2021
52348da
Add C2R Python layer normal and abnormal use cases (#29)
lijiaqi0612 Sep 10, 2021
b358cc5
complete rfft,rfft2,rfftn,ihfft,ihfft2,ihfftn unittest and doc string…
cxxly Sep 10, 2021
c825b1e
Documentation of the common interfaces of c2r and c2c (#31)
lijiaqi0612 Sep 10, 2021
1cf4587
clean c++ code (#32)
Sep 10, 2021
3765c70
Add numpy-based implementation of spectral ops (#33)
Sep 10, 2021
fcd9069
Add fft_c2r numpy based implementation for unittest. (#34)
Sep 10, 2021
f9e3309
Add deframe op and stft/istft api. (#23)
KPatr1ck Sep 10, 2021
c46fa79
Add overlap_add op and stft/istft api unittest (#35)
KPatr1ck Sep 10, 2021
fa40a1a
Add unittest for fft helper functions (#36)
Sep 10, 2021
f002f3a
complete static graph unittest for all public api (#37)
cxxly Sep 10, 2021
4c3a4b1
Unittest of op with FFT C2C, C2R and R2C added (#38)
Sep 10, 2021
91fa9cd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Sep 10, 2021
a4abc4e
add fft related options to CMakeLists.txt
Sep 10, 2021
b76c98c
fix typos and clean code (#39)
Sep 11, 2021
a714cfe
always convert numpy array to paddle.Tensor to avoid comparing numpy …
Sep 14, 2021
626a5b7
fix CI Errors: numpy dtype comparison, thrust when cuda is not availa…
Sep 14, 2021
d2eebba
remove inclusion of thrust, add __all__ list for fft (#42)
Sep 15, 2021
c0289d1
Add api doc and update unittest. (#43)
KPatr1ck Sep 15, 2021
b3d5f13
fix MKL-based FFT implementation (#44)
Sep 15, 2021
2cb21c0
remove code for debug (#45)
Sep 15, 2021
5ccdf98
use dynload for cufft (#46)
Sep 15, 2021
5e33e7f
Update doc and unittest. (#47)
KPatr1ck Sep 16, 2021
cc9d3a0
use dynload for cufft (#48)
Sep 16, 2021
150da5c
fix conflicts and merge upstream (#49)
Sep 16, 2021
0279d4a
Merge branch 'develop' into project_fft
Sep 16, 2021
1e16889
fix compile error: only link dyload_cuda when cuda is available (#50)
Sep 16, 2021
e804bd5
fix dynload for cufft on windows (#51)
Sep 16, 2021
ffcf187
add NOMINMAX to compile on windows (#52)
Sep 16, 2021
6c3322c
explicitly specify capture mode for lambdas (#55)
Sep 16, 2021
d700f45
fix fft sample (#53)
lijiaqi0612 Sep 16, 2021
b85788b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Sep 16, 2021
2c615bb
update scipy and numpy version for unittests of fft (#56)
Sep 16, 2021
e968c20
Add static graph unittests of frame and overlap_add api. (#57)
KPatr1ck Sep 17, 2021
76401f4
Remove cache of cuFFT & Disable ONEMKL (#59)
Sep 17, 2021
f8c2a2e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Sep 17, 2021
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -38,6 +38,8 @@ project(paddle CXX C)
# enable language CUDA
# TODO(Shibo Tao): remove find_package(CUDA) completely.
find_package(CUDA QUIET)
find_package(MKL CONFIG QUIET)
option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" OFF)
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
@@ -225,6 +227,7 @@ option(WITH_STRIP "Strip so files of Whl packages" OFF)
option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF)
option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF)
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
option(WITH_POCKETFFT "Compile with pocketfft support" ON)

# PY_VERSION
if(NOT PY_VERSION)
@@ -373,6 +376,10 @@ if (WITH_MIPS)
add_definitions(-DPADDLE_WITH_MIPS)
endif()

if (WITH_ONEMKL)
add_definitions(-DPADDLE_WITH_ONEMKL)
endif()

if (WITH_HETERPS)
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
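Editor's note: WITH_ONEMKL and WITH_POCKETFFT select the CPU FFT backend at compile time by defining PADDLE_WITH_ONEMKL and PADDLE_WITH_POCKETFFT (the latter in cmake/third_party.cmake further below). A minimal sketch of how code can branch on these macros — the two macros are the ones added in this PR; everything else here is the editor's illustration, not PR source:

#include <cstdio>

int main() {
#if defined(PADDLE_WITH_ONEMKL)
  std::puts("CPU FFT backend: oneMKL");    // built with -DPADDLE_WITH_ONEMKL
#elif defined(PADDLE_WITH_POCKETFFT)
  std::puts("CPU FFT backend: pocketfft"); // built with -DPADDLE_WITH_POCKETFFT
#else
  std::puts("no CPU FFT backend enabled");
#endif
}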
2 changes: 1 addition & 1 deletion cmake/FindGperftools.cmake
@@ -20,7 +20,7 @@
find_library(GPERFTOOLS_TCMALLOC
NAMES tcmalloc
HINTS ${Gperftools_ROOT_DIR}/lib)

find_library(GPERFTOOLS_PROFILER
NAMES profiler
HINTS ${Gperftools_ROOT_DIR}/lib)
44 changes: 44 additions & 0 deletions cmake/external/pocketfft.cmake
@@ -0,0 +1,44 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)


set(POCKETFFT_PATH "${THIRD_PARTY_PATH}/pocketfft" CACHE STRING "A path setting for external_pocketfft path.")
set(POCKETFFT_PREFIX_DIR ${POCKETFFT_PATH})

set(POCKETFFT_REPOSITORY https://gitlab.mpcdf.mpg.de/mtr/pocketfft.git)
set(POCKETFFT_TAG release_for_eigen)

SET(POCKETFFT_INCLUDE_DIR ${POCKETFFT_PREFIX_DIR}/src)
message("POCKETFFT_INCLUDE_DIR is ${POCKETFFT_INCLUDE_DIR}")
include_directories(${POCKETFFT_INCLUDE_DIR})

ExternalProject_Add(
extern_pocketfft
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
GIT_REPOSITORY ${POCKETFFT_REPOSITORY}
GIT_TAG ${POCKETFFT_TAG}
PREFIX ${POCKETFFT_PREFIX_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)

add_library(pocketfft INTERFACE)

add_dependencies(pocketfft extern_pocketfft)
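Editor's note: pocketfft is header-only, which is why the ExternalProject above has empty configure/build/install steps and is exposed as an INTERFACE library. A minimal sketch of the c2c entry point the CPU kernels can call, assuming the pocketfft_hdronly.h API at the release_for_eigen tag (editor's example, not code from this PR):

#include <complex>
#include <vector>
#include "pocketfft_hdronly.h"

int main() {
  std::vector<std::complex<double>> in(8, {1.0, 0.0}), out(8);
  pocketfft::shape_t shape{in.size()};                       // 1-D, length 8
  pocketfft::stride_t stride{sizeof(std::complex<double>)};  // contiguous layout
  pocketfft::shape_t axes{0};                                // transform along axis 0
  // forward complex-to-complex transform, scale factor 1.0
  pocketfft::c2c(shape, stride, stride, axes, pocketfft::FORWARD,
                 in.data(), out.data(), 1.0);
  // for a constant input, out[0] == 8 and the remaining bins are ~0
}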
6 changes: 6 additions & 0 deletions cmake/third_party.cmake
@@ -361,4 +361,10 @@ if (WITH_CRYPTO)
add_definitions(-DPADDLE_WITH_CRYPTO)
endif (WITH_CRYPTO)

if (WITH_POCKETFFT)
include(external/pocketfft)
list(APPEND third_party_deps extern_pocketfft)
add_definitions(-DPADDLE_WITH_POCKETFFT)
endif (WITH_POCKETFFT)

add_custom_target(third_party ALL DEPENDS ${third_party_deps})
18 changes: 17 additions & 1 deletion paddle/fluid/framework/data_type.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <string>
#include <typeindex>

@@ -170,11 +171,26 @@ extern inline proto::VarType::Type ToComplexType(proto::VarType::Type t) {
return proto::VarType::COMPLEX128;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unknown complex value data type (%s), now only support float32 and "
"Unknown real value data type (%s), now only support float32 and "
"float64.",
DataTypeToString(t)));
}
}

extern inline proto::VarType::Type ToRealType(proto::VarType::Type t) {
switch (t) {
case proto::VarType::COMPLEX64:
return proto::VarType::FP32;
case proto::VarType::COMPLEX128:
return proto::VarType::FP64;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unknown complex value data type (%s), now only support complex64 "
"and "
"complex128.",
DataTypeToString(t)));
}
}

} // namespace framework
} // namespace paddle
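Editor's note: ToRealType is the inverse of ToComplexType above; C2R transforms such as irfft use this kind of mapping to infer a real output dtype from a complex input. A self-contained analogue (the enum here is illustrative, not Paddle's proto::VarType):

#include <cassert>
#include <stdexcept>

enum class VarType { FP32, FP64, COMPLEX64, COMPLEX128 };

VarType ToRealType(VarType t) {
  switch (t) {
    case VarType::COMPLEX64:  return VarType::FP32;  // complex<float>  -> float
    case VarType::COMPLEX128: return VarType::FP64;  // complex<double> -> double
    default:
      throw std::invalid_argument("only complex64/complex128 have a real counterpart");
  }
}

int main() {
  assert(ToRealType(VarType::COMPLEX128) == VarType::FP64);
}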
12 changes: 11 additions & 1 deletion paddle/fluid/operators/CMakeLists.txt
@@ -59,6 +59,10 @@ if (WITH_GPU)
endif()
endif()

if (WITH_POCKETFFT)
SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} pocketfft)
endif()


SET(OP_MKL_DEPS "")
if (NOT WITH_MKL OR NOT WITH_AVX)
@@ -75,7 +79,7 @@ if(WITH_UNITY_BUILD)
endif()

register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})

op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})

@@ -94,6 +98,12 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()

if (WITH_GPU AND (NOT WITH_ROCM))
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
else()
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
endif()

op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
op_library(eye_op DEPS ${OP_HEADER_DEPS})
op_library(recurrent_op DEPS ${OP_HEADER_DEPS})
13 changes: 11 additions & 2 deletions paddle/fluid/operators/concat_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/concat_op.h"

#include <paddle/fluid/platform/complex.h>
#include <memory>
#include <string>
#include <vector>
@@ -237,7 +238,11 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>);
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
@@ -247,4 +252,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>);
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
13 changes: 11 additions & 2 deletions paddle/fluid/operators/concat_op.cu.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
@@ -24,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>);
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext,
plat::complex<float>>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
@@ -33,4 +38,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>);
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext,
plat::complex<float>>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
3 changes: 3 additions & 0 deletions paddle/fluid/operators/eigen/scale.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
@@ -42,6 +43,8 @@ template struct EigenScale<Eigen::DefaultDevice, int8_t>;
template struct EigenScale<Eigen::DefaultDevice, int16_t>;
template struct EigenScale<Eigen::DefaultDevice, int>;
template struct EigenScale<Eigen::DefaultDevice, int64_t>;
template struct EigenScale<Eigen::DefaultDevice, platform::complex<float>>;
template struct EigenScale<Eigen::DefaultDevice, platform::complex<double>>;

} // namespace operators
} // namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/operators/eigen/scale.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
@@ -41,6 +42,8 @@ template struct EigenScale<Eigen::GpuDevice, int16_t>;
template struct EigenScale<Eigen::GpuDevice, int>;
template struct EigenScale<Eigen::GpuDevice, int64_t>;
template struct EigenScale<Eigen::GpuDevice, platform::float16>;
template struct EigenScale<Eigen::GpuDevice, platform::complex<float>>;
template struct EigenScale<Eigen::GpuDevice, platform::complex<double>>;

} // namespace operators
} // namespace paddle
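Editor's note: the new complex instantiations of EigenScale support the FFT normalization modes — e.g. the inverse transform divides a complex spectrum by n — and the complex scaling that appears in FFT gradients. A sketch of the elementwise operation being instantiated, using std::complex in place of Eigen (editor's illustration):

#include <complex>
#include <vector>

// elementwise scale of a complex tensor, as EigenScale<_, complex<T>> provides
void ScaleInPlace(std::vector<std::complex<float>>& x, float factor) {
  for (auto& v : x) v *= factor;
}

int main() {
  std::vector<std::complex<float>> spectrum(8, {4.0f, 4.0f});
  ScaleInPlace(spectrum, 1.0f / spectrum.size());  // "backward" norm: scale ifft by 1/n
}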
13 changes: 11 additions & 2 deletions paddle/fluid/operators/fill_zeros_like_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
@@ -93,12 +94,20 @@ REGISTER_OP_CPU_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);

REGISTER_OP_CPU_KERNEL(
fill_zeros_like2,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
13 changes: 11 additions & 2 deletions paddle/fluid/operators/fill_zeros_like_op.cu.cc
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
@@ -25,7 +26,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);

REGISTER_OP_CUDA_KERNEL(
fill_zeros_like2,
@@ -35,4 +40,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
6 changes: 5 additions & 1 deletion paddle/fluid/operators/flip_op.cc
@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
@@ -145,6 +146,7 @@ class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
ops::FlipOpGradMaker<paddle::framework::OpDesc>,
ops::FlipOpGradMaker<paddle::imperative::OpBase>);
@@ -153,7 +155,9 @@ REGISTER_OP_CPU_KERNEL(
ops::FlipKernel<paddle::platform::CPUDeviceContext, double>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>);
ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<float>>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<double>>);

/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(flip)
6 changes: 5 additions & 1 deletion paddle/fluid/operators/flip_op.cu
@@ -16,6 +16,7 @@ limitations under the License. */

#include <vector>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
@@ -163,4 +164,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>);
ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::complex<float>>,
ops::FlipKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
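Editor's note: several of the complex kernel registrations above (flip, concat, fill_zeros_like) exist to support the onesided r2c path — see commit "fill fft r2c result with conjugate symmetry (#19)". For real input of length n, rfft keeps n/2 + 1 bins and the rest are recovered by reversing and conjugating the interior bins, i.e. X[n-k] = conj(X[k]). A hedged sketch of that fill (editor's illustration, not the PR's kernel code):

#include <complex>
#include <vector>

std::vector<std::complex<double>> FullSpectrum(
    const std::vector<std::complex<double>>& half, size_t n) {
  std::vector<std::complex<double>> full(n);
  for (size_t k = 0; k < half.size(); ++k) full[k] = half[k];  // kept bins
  // mirror: a reversal plus conjugation, which is where flip's complex kernels come in
  for (size_t k = half.size(); k < n; ++k) full[k] = std::conj(full[n - k]);
  return full;
}

int main() {
  // half holds the n/2 + 1 onesided bins of some length-6 real signal
  std::vector<std::complex<double>> half{{6, 0}, {-1, 1}, {-1, 0}, {2, 0}};
  auto full = FullSpectrum(half, 6);  // full[5] == conj(full[1]), etc.
}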