diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 208d0615b23cec..83207754709e4a 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -590,21 +590,6 @@ void CumWithIndicesInferMeta(const MetaTensor& x, phi::errors::InvalidArgument( "dtype of indices must be DataType::INT32 or DataType::INT64")); - if (dtype == DataType::INT32) { - int _axis = 0; - if (axis < 0) { - _axis = axis + x_dims.size(); - } else { - _axis = axis; - } - PADDLE_ENFORCE_LT( - common::vectorize(x_dims)[_axis], - INT32_MAX, - phi::errors::OutOfRange( - "cummax with axis %ld may be overflow, set dtype int64 to continue", - axis)); - } - if (x_dims.size() > 0) { PADDLE_ENFORCE_GE( axis, @@ -633,6 +618,21 @@ void CumWithIndicesInferMeta(const MetaTensor& x, axis)); } + if (dtype == DataType::INT32) { + int _axis = 0; + if (axis < 0) { + _axis = axis + x_dims.size(); + } else { + _axis = axis; + } + PADDLE_ENFORCE_LT( + common::vectorize(x_dims)[_axis], + INT32_MAX, + phi::errors::OutOfRange( + "cummax with axis %ld may be overflow, set dtype int64 to continue", + axis)); + } + out->set_dims(x_dims); out->set_dtype(x.dtype()); out->share_lod(x);