Skip to content

Commit b3d5f13

Browse files
author
Feiyu Chan
authored
fix MKL-based FFT implementation (#44)
* Fix the MKL-based FFT implementation: for R2C and C2R transforms, MKL's forward domain is always real (DFTI_REAL); only C2C uses the complex domain
1 parent c0289d1 commit b3d5f13

File tree

1 file changed

+62
-18
lines changed

1 file changed

+62
-18
lines changed

paddle/fluid/operators/spectral_op.cc

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -399,15 +399,20 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
399399
}
400400
}();
401401

402-
const bool complex_input = framework::IsComplexType(in_dtype);
403-
const bool complex_output = framework::IsComplexType(out_dtype);
404-
const DFTI_CONFIG_VALUE domain = [&] {
405-
if (forward) {
406-
return complex_input ? DFTI_COMPLEX : DFTI_REAL;
407-
} else {
408-
return complex_output ? DFTI_COMPLEX : DFTI_REAL;
409-
}
410-
}();
402+
// C2C, R2C, C2R
403+
const FFTTransformType fft_type = GetFFTTransformType(in_dtype, out_dtype);
404+
const DFTI_CONFIG_VALUE domain =
405+
(fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;
406+
407+
// const bool complex_input = framework::IsComplexType(in_dtype);
408+
// const bool complex_output = framework::IsComplexType(out_dtype);
409+
// const DFTI_CONFIG_VALUE domain = [&] {
410+
// if (forward) {
411+
// return complex_input ? DFTI_COMPLEX : DFTI_REAL;
412+
// } else {
413+
// return complex_output ? DFTI_COMPLEX : DFTI_REAL;
414+
// }
415+
// }();
411416

412417
DftiDescriptor descriptor;
413418
std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
@@ -442,7 +447,7 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
442447
mkl_out_stride.data()));
443448

444449
// conjugate even storage
445-
if (!complex_input || !complex_output) {
450+
if (!(fft_type == FFTTransformType::C2C)) {
446451
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE,
447452
DFTI_COMPLEX_COMPLEX));
448453
}
@@ -455,8 +460,16 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
455460
((normalization == FFTNormMode::by_sqrt_n)
456461
? 1.0 / std::sqrt(static_cast<double>(signal_numel))
457462
: 1.0 / static_cast<double>(signal_numel));
458-
const auto scale_direction =
459-
forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE;
463+
const auto scale_direction = [&]() {
464+
if (fft_type == FFTTransformType::R2C ||
465+
(fft_type == FFTTransformType::C2C && forward)) {
466+
return DFTI_FORWARD_SCALE;
467+
} else {
468+
// (fft_type == FFTTransformType::C2R ||
469+
// (fft_type == FFTTransformType::C2C && !forward))
470+
return DFTI_BACKWARD_SCALE;
471+
}
472+
}();
460473
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale));
461474
}
462475

@@ -541,13 +554,44 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
541554
DftiDescriptor desc =
542555
_plan_mkl_fft(x->type(), out->type(), input_stride, output_stride,
543556
signal_sizes, normalization, forward);
544-
// dump_descriptor(desc.get());
545-
if (forward) {
557+
dump_descriptor(desc.get());
558+
559+
const FFTTransformType fft_type = GetFFTTransformType(x->type(), out->type());
560+
if (fft_type == FFTTransformType::C2R && forward) {
561+
framework::Tensor collapsed_input_conj(collapsed_input.type());
562+
collapsed_input_conj.mutable_data<Ti>(collapsed_input.dims(),
563+
ctx.GetPlace());
564+
// conjugate the input
565+
platform::ForRange<DeviceContext> for_range(ctx, collapsed_input.numel());
566+
math::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(),
567+
collapsed_input.numel(),
568+
collapsed_input_conj.data<Ti>());
569+
for_range(functor);
570+
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
571+
collapsed_input_conj.data<void>(),
572+
collapsed_output.data<void>()));
573+
} else if (fft_type == FFTTransformType::R2C && !forward) {
574+
framework::Tensor collapsed_output_conj(collapsed_output.type());
575+
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
576+
ctx.GetPlace());
546577
MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data<void>(),
547-
collapsed_output.data<void>()));
578+
collapsed_output_conj.data<void>()));
579+
// conjugate the output
580+
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
581+
math::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
582+
collapsed_output.numel(),
583+
collapsed_output.data<To>());
584+
for_range(functor);
548585
} else {
549-
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), collapsed_input.data<void>(),
550-
collapsed_output.data<void>()));
586+
if (forward) {
587+
MKL_DFTI_CHECK(DftiComputeForward(desc.get(),
588+
collapsed_input.data<void>(),
589+
collapsed_output.data<void>()));
590+
} else {
591+
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
592+
collapsed_input.data<void>(),
593+
collapsed_output.data<void>()));
594+
}
551595
}
552596

553597
// resize for the collapsed output
@@ -598,7 +642,7 @@ struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
598642
FFTC2CFunctor<platform::CPUDeviceContext, Ti, Ti> c2c_functor;
599643
c2c_functor(ctx, x, &temp, c2c_dims, normalization, forward);
600644

601-
const std::vector<int64_t> new_axes(axes.back());
645+
const std::vector<int64_t> new_axes{axes.back()};
602646
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, &temp, out, new_axes,
603647
normalization, forward);
604648
} else {

0 commit comments

Comments
 (0)