@@ -399,15 +399,20 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
399399 }
400400 }();
401401
402- const bool complex_input = framework::IsComplexType (in_dtype);
403- const bool complex_output = framework::IsComplexType (out_dtype);
404- const DFTI_CONFIG_VALUE domain = [&] {
405- if (forward) {
406- return complex_input ? DFTI_COMPLEX : DFTI_REAL;
407- } else {
408- return complex_output ? DFTI_COMPLEX : DFTI_REAL;
409- }
410- }();
402+ // C2C, R2C, C2R
403+ const FFTTransformType fft_type = GetFFTTransformType (in_dtype, out_dtype);
404+ const DFTI_CONFIG_VALUE domain =
405+ (fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;
406+
407+ // const bool complex_input = framework::IsComplexType(in_dtype);
408+ // const bool complex_output = framework::IsComplexType(out_dtype);
409+ // const DFTI_CONFIG_VALUE domain = [&] {
410+ // if (forward) {
411+ // return complex_input ? DFTI_COMPLEX : DFTI_REAL;
412+ // } else {
413+ // return complex_output ? DFTI_COMPLEX : DFTI_REAL;
414+ // }
415+ // }();
411416
412417 DftiDescriptor descriptor;
413418 std::vector<MKL_LONG> fft_sizes (signal_sizes.cbegin (), signal_sizes.cend ());
@@ -442,7 +447,7 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
442447 mkl_out_stride.data ()));
443448
444449 // conjugate even storage
445- if (!complex_input || !complex_output ) {
450+ if (!(fft_type == FFTTransformType::C2C) ) {
446451 MKL_DFTI_CHECK (DftiSetValue (descriptor.get (), DFTI_CONJUGATE_EVEN_STORAGE,
447452 DFTI_COMPLEX_COMPLEX));
448453 }
@@ -455,8 +460,16 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
455460 ((normalization == FFTNormMode::by_sqrt_n)
456461 ? 1.0 / std::sqrt (static_cast <double >(signal_numel))
457462 : 1.0 / static_cast <double >(signal_numel));
458- const auto scale_direction =
459- forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE;
463+ const auto scale_direction = [&]() {
464+ if (fft_type == FFTTransformType::R2C ||
465+ (fft_type == FFTTransformType::C2C && forward)) {
466+ return DFTI_FORWARD_SCALE;
467+ } else {
468+ // (fft_type == FFTTransformType::C2R ||
469+ // (fft_type == FFTTransformType::C2C && !forward))
470+ return DFTI_BACKWARD_SCALE;
471+ }
472+ }();
460473 MKL_DFTI_CHECK (DftiSetValue (descriptor.get (), scale_direction, scale));
461474 }
462475
@@ -541,13 +554,44 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
541554 DftiDescriptor desc =
542555 _plan_mkl_fft (x->type (), out->type (), input_stride, output_stride,
543556 signal_sizes, normalization, forward);
544- // dump_descriptor(desc.get());
545- if (forward) {
557+ dump_descriptor (desc.get ());
558+
559+ const FFTTransformType fft_type = GetFFTTransformType (x->type (), out->type ());
560+ if (fft_type == FFTTransformType::C2R && forward) {
561+ framework::Tensor collapsed_input_conj (collapsed_input.type ());
562+ collapsed_input_conj.mutable_data <Ti>(collapsed_input.dims (),
563+ ctx.GetPlace ());
564+ // conjugate the input
565+ platform::ForRange<DeviceContext> for_range (ctx, collapsed_input.numel ());
566+ math::ConjFunctor<Ti> functor (collapsed_input.data <Ti>(),
567+ collapsed_input.numel (),
568+ collapsed_input_conj.data <Ti>());
569+ for_range (functor);
570+ MKL_DFTI_CHECK (DftiComputeBackward (desc.get (),
571+ collapsed_input_conj.data <void >(),
572+ collapsed_output.data <void >()));
573+ } else if (fft_type == FFTTransformType::R2C && !forward) {
574+ framework::Tensor collapsed_output_conj (collapsed_output.type ());
575+ collapsed_output_conj.mutable_data <To>(collapsed_output.dims (),
576+ ctx.GetPlace ());
546577 MKL_DFTI_CHECK (DftiComputeForward (desc.get (), collapsed_input.data <void >(),
547- collapsed_output.data <void >()));
578+ collapsed_output_conj.data <void >()));
579+ // conjugate the output
580+ platform::ForRange<DeviceContext> for_range (ctx, collapsed_output.numel ());
581+ math::ConjFunctor<To> functor (collapsed_output_conj.data <To>(),
582+ collapsed_output.numel (),
583+ collapsed_output.data <To>());
584+ for_range (functor);
548585 } else {
549- MKL_DFTI_CHECK (DftiComputeBackward (desc.get (), collapsed_input.data <void >(),
550- collapsed_output.data <void >()));
586+ if (forward) {
587+ MKL_DFTI_CHECK (DftiComputeForward (desc.get (),
588+ collapsed_input.data <void >(),
589+ collapsed_output.data <void >()));
590+ } else {
591+ MKL_DFTI_CHECK (DftiComputeBackward (desc.get (),
592+ collapsed_input.data <void >(),
593+ collapsed_output.data <void >()));
594+ }
551595 }
552596
553597 // resize for the collapsed output
@@ -598,7 +642,7 @@ struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
598642 FFTC2CFunctor<platform::CPUDeviceContext, Ti, Ti> c2c_functor;
599643 c2c_functor (ctx, x, &temp, c2c_dims, normalization, forward);
600644
601- const std::vector<int64_t > new_axes ( axes.back ()) ;
645+ const std::vector<int64_t > new_axes{ axes.back ()} ;
602646 exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, &temp, out, new_axes,
603647 normalization, forward);
604648 } else {
0 commit comments