Skip to content

Commit 7a22378

Browse files
author
Feiyu Chan
authored
miscellaneous fixes for python APIs (#26)
* add placeholder for unittests * resize fft inputs before computation if n or s is provided. * add complex kernels for pad and pad_grad * simplify argument checking. * add type promotion * add int to float or complex promotion * fix output data type for static mode * fix fft's input dtype dispatch, import fft to paddle
1 parent 75c2ca0 commit 7a22378

File tree

5 files changed

+246
-168
lines changed

5 files changed

+246
-168
lines changed

paddle/fluid/operators/pad_op.cc

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/pad_op.h"
1616
#include <memory>
17+
#include "paddle/fluid/platform/complex.h"
1718

1819
namespace paddle {
1920
namespace operators {
@@ -170,20 +171,36 @@ REGISTER_OP_CPU_KERNEL(
170171
pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>,
171172
ops::PadKernel<paddle::platform::CPUDeviceContext, double>,
172173
ops::PadKernel<paddle::platform::CPUDeviceContext, int>,
173-
ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>);
174+
ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>,
175+
ops::PadKernel<paddle::platform::CPUDeviceContext,
176+
paddle::platform::complex<float>>,
177+
ops::PadKernel<paddle::platform::CPUDeviceContext,
178+
paddle::platform::complex<double>>);
174179
REGISTER_OP_CPU_KERNEL(
175180
pad_grad, ops::PadGradKernel<paddle::platform::CPUDeviceContext, float>,
176-
ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>);
181+
ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>,
182+
ops::PadGradKernel<paddle::platform::CPUDeviceContext,
183+
paddle::platform::complex<float>>,
184+
ops::PadGradKernel<paddle::platform::CPUDeviceContext,
185+
paddle::platform::complex<double>>);
177186

178187
REGISTER_OP_CUDA_KERNEL(
179188
pad, ops::PadKernel<paddle::platform::CUDADeviceContext, double>,
180189
ops::PadKernel<paddle::platform::CUDADeviceContext, float>,
181190
ops::PadKernel<paddle::platform::CUDADeviceContext, int>,
182191
ops::PadKernel<paddle::platform::CUDADeviceContext, int64_t>,
183192
ops::PadKernel<paddle::platform::CUDADeviceContext,
184-
paddle::platform::float16>);
193+
paddle::platform::float16>,
194+
ops::PadKernel<paddle::platform::CUDADeviceContext,
195+
paddle::platform::complex<float>>,
196+
ops::PadKernel<paddle::platform::CUDADeviceContext,
197+
paddle::platform::complex<double>>);
185198
REGISTER_OP_CUDA_KERNEL(
186199
pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, double>,
187200
ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>,
188201
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
189-
paddle::platform::float16>);
202+
paddle::platform::float16>,
203+
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
204+
paddle::platform::complex<float>>,
205+
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
206+
paddle::platform::complex<double>>);

paddle/fluid/operators/spectral_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,10 @@ class FFTR2COp : public framework::OperatorWithKernel {
150150
void InferShape(framework::InferShapeContext* ctx) const override {
151151
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
152152
platform::errors::InvalidArgument(
153-
"Input(%s) of FFTC2ROp should not be null.", "X"));
153+
"Input(%s) of FFTR2COp should not be null.", "X"));
154154
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
155155
platform::errors::InvalidArgument(
156-
"Output(%s) of FFTC2ROp should not be null.", "Out"));
156+
"Output(%s) of FFTR2COp should not be null.", "Out"));
157157
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
158158
const bool onesided = ctx->Attrs().Get<bool>("onesided");
159159
if (!onesided) {

python/paddle/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import paddle.static # noqa: F401
6565
import paddle.vision # noqa: F401
6666

67+
from .tensor import fft
6768
from .tensor.random import bernoulli # noqa: F401
6869

6970
from .tensor.attribute import rank # noqa: F401

python/paddle/tensor/attribute.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ def _complex_to_real_dtype(dtype):
3535
return dtype
3636

3737

38+
def _real_to_complex_dtype(dtype):
39+
if dtype == core.VarDesc.VarType.FP32:
40+
return core.VarDesc.VarType.COMPLEX64
41+
elif dtype == core.VarDesc.VarType.FP64:
42+
return core.VarDesc.VarType.COMPLEX128
43+
else:
44+
return dtype
45+
46+
3847
def is_complex(x):
3948
dtype = x.dtype
4049
is_complex_dtype = (dtype == core.VarDesc.VarType.COMPLEX64 or
@@ -51,6 +60,16 @@ def is_floating_point(x):
5160
return is_fp_dtype
5261

5362

63+
def is_interger(x):
64+
dtype = x.dtype
65+
is_int_dtype = (dtype == core.VarDesc.VarType.UINT8 or
66+
dtype == core.VarDesc.VarType.INT8 or
67+
dtype == core.VarDesc.VarType.INT16 or
68+
dtype == core.VarDesc.VarType.INT32 or
69+
dtype == core.VarDesc.VarType.INT64)
70+
return is_int_dtype
71+
72+
5473
def real(x, name=None):
5574
"""
5675
Returns a new tensor containing real values of the input tensor.

0 commit comments

Comments (0)