@@ -240,15 +240,29 @@ class FFTC2ROp : public framework::OperatorWithKernel {
240240 OP_INOUT_CHECK (ctx->HasOutput (" Out" ), " Output" , " Out" , " fft_c2r" );
241241
242242 const auto axes = ctx->Attrs ().Get <std::vector<int64_t >>(" axes" );
243+ const auto x_dim = ctx->GetInputDim (" X" );
244+ for (size_t i = 0 ; i < axes.size () - 1L ; i++) {
245+ const auto fft_n_point = (x_dim[axes[i]] - 1 ) * 2 ;
246+ PADDLE_ENFORCE_GT (fft_n_point, 0 ,
247+ platform::errors::InvalidArgument (
248+ " Invalid fft n-point (%d)." , fft_n_point));
249+ }
243250
244251 const int64_t last_dim_size = ctx->Attrs ().Get <int64_t >(" last_dim_size" );
245252 framework::DDim out_dim (ctx->GetInputDim (" X" ));
246253 const int64_t last_fft_axis = axes.back ();
247254 if (last_dim_size == 0 ) {
248255 const int64_t last_fft_dim_size = out_dim.at (last_fft_axis);
249- out_dim.at (last_fft_axis) = (last_fft_dim_size - 1 ) * 2 ;
256+ const int64_t fft_n_point = (last_fft_dim_size - 1 ) * 2 ;
257+ PADDLE_ENFORCE_GT (fft_n_point, 0 ,
258+ platform::errors::InvalidArgument (
259+ " Invalid fft n-point (%d)." , fft_n_point));
260+ out_dim.at (last_fft_axis) = fft_n_point;
250261 } else {
251- out_dim.at (last_fft_axis) = ctx->Attrs ().Get <int64_t >(" last_dim_size" );
262+ PADDLE_ENFORCE_GT (last_dim_size, 0 ,
263+ platform::errors::InvalidArgument (
264+ " Invalid fft n-point (%d)." , last_dim_size));
265+ out_dim.at (last_fft_axis) = last_dim_size;
252266 }
253267 ctx->SetOutputDim (" Out" , out_dim);
254268 }
@@ -681,11 +695,11 @@ struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
681695 const auto & input_dim = x->dims ();
682696 const std::vector<size_t > in_sizes =
683697 framework::vectorize<size_t >(input_dim);
684- std::vector<int64_t > in_strides =
685- framework::vectorize<int64_t >(framework::stride (input_dim));
698+ std::vector<std:: ptrdiff_t > in_strides =
699+ framework::vectorize<std:: ptrdiff_t >(framework::stride (input_dim));
686700 const int64_t data_size = sizeof (C);
687701 std::transform (in_strides.begin (), in_strides.end (), in_strides.begin (),
688- [](int64_t s) { return s * data_size; });
702+ [](std:: ptrdiff_t s) { return s * data_size; });
689703
690704 const auto * in_data = reinterpret_cast <const C*>(x->data <Ti>());
691705 auto * out_data = reinterpret_cast <C*>(out->data <To>());
@@ -714,24 +728,24 @@ struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
714728 const auto & input_dim = x->dims ();
715729 const std::vector<size_t > in_sizes =
716730 framework::vectorize<size_t >(input_dim);
717- std::vector<int64_t > in_strides =
718- framework::vectorize<int64_t >(framework::stride (input_dim));
731+ std::vector<std:: ptrdiff_t > in_strides =
732+ framework::vectorize<std:: ptrdiff_t >(framework::stride (input_dim));
719733 {
720734 const int64_t data_size = sizeof (R);
721735 std::transform (in_strides.begin (), in_strides.end (), in_strides.begin (),
722- [](int64_t s) { return s * data_size; });
736+ [](std:: ptrdiff_t s) { return s * data_size; });
723737 }
724738
725739 const auto & output_dim = out->dims ();
726740 const std::vector<size_t > out_sizes =
727741 framework::vectorize<size_t >(output_dim);
728- std::vector<int64_t > out_strides =
729- framework::vectorize<int64_t >(framework::stride (output_dim));
742+ std::vector<std:: ptrdiff_t > out_strides =
743+ framework::vectorize<std:: ptrdiff_t >(framework::stride (output_dim));
730744 {
731745 const int64_t data_size = sizeof (C);
732746 std::transform (out_strides.begin (), out_strides.end (),
733747 out_strides.begin (),
734- [](int64_t s) { return s * data_size; });
748+ [](std:: ptrdiff_t s) { return s * data_size; });
735749 }
736750
737751 const auto * in_data = x->data <R>();
@@ -761,24 +775,24 @@ struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
761775 const auto & input_dim = x->dims ();
762776 const std::vector<size_t > in_sizes =
763777 framework::vectorize<size_t >(input_dim);
764- std::vector<int64_t > in_strides =
765- framework::vectorize<int64_t >(framework::stride (input_dim));
778+ std::vector<std:: ptrdiff_t > in_strides =
779+ framework::vectorize<std:: ptrdiff_t >(framework::stride (input_dim));
766780 {
767781 const int64_t data_size = sizeof (C);
768782 std::transform (in_strides.begin (), in_strides.end (), in_strides.begin (),
769- [](int64_t s) { return s * data_size; });
783+ [](std:: ptrdiff_t s) { return s * data_size; });
770784 }
771785
772786 const auto & output_dim = out->dims ();
773787 const std::vector<size_t > out_sizes =
774788 framework::vectorize<size_t >(output_dim);
775- std::vector<int64_t > out_strides =
776- framework::vectorize<int64_t >(framework::stride (output_dim));
789+ std::vector<std:: ptrdiff_t > out_strides =
790+ framework::vectorize<std:: ptrdiff_t >(framework::stride (output_dim));
777791 {
778792 const int64_t data_size = sizeof (R);
779793 std::transform (out_strides.begin (), out_strides.end (),
780794 out_strides.begin (),
781- [](int64_t s) { return s * data_size; });
795+ [](std:: ptrdiff_t s) { return s * data_size; });
782796 }
783797
784798 const auto * in_data = reinterpret_cast <const C*>(x->data <Ti>());
0 commit comments