diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 54e7a374338c25..2987c00b8c3a76 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -74,6 +74,62 @@ class CumsumGradMaker : public framework::SingleGradOpMaker { } }; +class CummaxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + +class CummaxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input of cummax operator"); + // AddOutput("Out_values", "Output values of cummax operator"); + // AddOutput("Out_indices", "Output indices of cummax operator"); + AddOutput("Out", "Output values of cummax operator"); + AddAttr("axis", + "The dimension to operate along. -1 means the last " + "dimension [default: -1].") + .SetDefault(-1); + AddAttr("flatten", + "Whether to compute the cummax over the flattened array. " + "[default: false].") + .SetDefault(false); + AddAttr("exclusive", + "Whether to perform exclusive cummax. [default: false].") + .SetDefault(false); + AddAttr("reverse", + "If true, the cummax is performed in the reversed direction. " + "[default: false].") + .SetDefault(false); + AddComment(R"DOC( +The cumulative maximum and corresponding index of the elements along a given axis. +By default, the first element of the out_values is the same as the first element of +the input. If exclusive is true, the first element of the result is 0. +)DOC"); + } +}; + +template +class CummaxGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("cummax"); + // grad_op->SetInput("X", this->OutputGrad("Out_values")); + grad_op->SetInput("X", this->OutputGrad("Out")); + grad_op->SetOutput("Out", this->InputGrad("X")); + grad_op->SetAttr("axis", PADDLE_GET_CONST(int, this->GetAttr("axis"))); + grad_op->SetAttr("flatten", + PADDLE_GET_CONST(bool, this->GetAttr("flatten"))); + grad_op->SetAttr("reverse", + !PADDLE_GET_CONST(bool, this->GetAttr("reverse"))); + grad_op->SetAttr("exclusive", + PADDLE_GET_CONST(bool, this->GetAttr("exclusive"))); + } +}; + class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -149,6 +205,9 @@ using CPU = phi::CPUContext; DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor, PD_INFER_META(phi::CumInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(cummax, + CummaxInferShapeFunctor, + PD_INFER_META(phi::CummaxInferMeta)); DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp, LogcumsumexpInferShapeFunctor, PD_INFER_META(phi::CumInferMeta)); @@ -158,6 +217,12 @@ REGISTER_OPERATOR(cumsum, ops::CumsumGradMaker, ops::CumsumGradMaker, CumsumInferShapeFunctor); +REGISTER_OPERATOR(cummax, + ops::CummaxOp, + ops::CummaxOpMaker, + ops::CummaxGradOpMaker, + ops::CummaxGradOpMaker, + CummaxInferShapeFunctor); REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker, diff --git a/paddle/phi/api/yaml/api.yaml b/paddle/phi/api/yaml/api.yaml index 1156206ee4b51a..a64e8b9a24775c 100644 --- a/paddle/phi/api/yaml/api.yaml +++ b/paddle/phi/api/yaml/api.yaml @@ -43,6 +43,15 @@ data_type : x backward : cross_grad +- api : cummax + args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse) + output : Tensor(out) + infer_meta : + func : CummaxInferMeta + kernel : + func : cummax + backward : cummax_grad + - api : diag args : (Tensor x, int offset = 0, float padding_value = 0.0) output : Tensor diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 53cdc97a716d7a..dbb9614600d4ff 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -39,6 +39,15 @@ func : cross_grad data_type : out_grad +- backward_api : cummax_grad + forward : cummax(Tensor x, int axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] + args : (Tensor out_grad, int axis, bool flatten, bool exclusive, bool reverse) + output : Tensor(x_grad) + invoke : cummax(out_grad, axis, flatten, exclusive, !reverse) + - backward_api : diag_grad forward : diag (Tensor x, int offset, float padding_value) -> Tensor(out) args : (Tensor x, Tensor out_grad, int offset) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 74705c3759da3f..d05e776e5a5ecc 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -345,6 +345,33 @@ void CumInferMeta(const MetaTensor& x, out->share_lod(x); } +void CummaxInferMeta(const MetaTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + // MetaTensor* out_values, + // MetaTensor* out_indices + MetaTensor* out) { + auto x_dims = x.dims(); + if (flatten) { + // out_values->set_dims(phi::make_ddim({phi::product(x_dims)})); + // out_indices->set_dims(phi::make_ddim({phi::product(x_dims)})); + out->set_dims(phi::make_ddim({phi::product(x_dims)})); + } else { + // out_values->set_dims(x_dims); + // out_indices->set_dims(x_dims); + out->set_dims(x_dims); + } + // out_values->set_dtype(x.dtype()); + // out_indices->set_dtype(DataType::INT64); + out->set_dtype(x.dtype()); + + // out_values->share_lod(x); + // out_indices->share_lod(x); + out->share_lod(x); +} + void CropTensorInferMeta(const MetaTensor& x, const IntArray& shape, const IntArray& offsets, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index bd35855a431298..4cc124b51b74a6 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -84,6 +84,15 @@ void CumInferMeta(const MetaTensor& x, bool reverse, MetaTensor* out); +void CummaxInferMeta(const MetaTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + // MetaTensor* out_values, + // MetaTensor* out_indices); + MetaTensor* out); + void DecodeJpegInferMeta(const MetaTensor& x, const std::string& mode, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc index cd171cc8fc5fc8..959e211b0a7c51 100644 --- a/paddle/phi/kernels/cpu/cum_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -146,6 +146,20 @@ void CumsumKernel(const Context& dev_ctx, dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); } +template +void CummaxKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Reducer = Eigen::internal::MaxReducer; + auto reducer = Reducer(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); +} + template struct LogSumExp { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, @@ -267,5 +281,19 @@ PD_REGISTER_KERNEL(cumsum, int, int64_t) {} -PD_REGISTER_KERNEL( - logcumsumexp, CPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {} +PD_REGISTER_KERNEL(cummax, + CPU, + ALL_LAYOUT, + phi::CummaxKernel, + float, + double, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(logcumsumexp, + CPU, + ALL_LAYOUT, + phi::LogcumsumexpKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cum_kernel.h b/paddle/phi/kernels/cum_kernel.h index 38cdbd7787bafa..9e3ff13eb092a9 100644 --- a/paddle/phi/kernels/cum_kernel.h +++ b/paddle/phi/kernels/cum_kernel.h @@ -27,6 +27,15 @@ void CumsumKernel(const Context& dev_ctx, bool reverse, DenseTensor* out); +template +void CummaxKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out); + template void LogcumsumexpKernel(const Context& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu index 40d7f74379fa74..16337e08e70717 100644 --- a/paddle/phi/kernels/gpu/cum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -139,6 +139,14 @@ struct LogAddExp { } }; +struct Max { + template + __host__ __device__ __forceinline__ T operator()(const T& a, + const T& b) const { + return std::max(a, b); + } +}; + template struct Identity; @@ -364,6 +372,20 @@ void CumsumKernel(const Context& dev_ctx, dev_ctx, x, axis, flatten, exclusive, reverse, op, out); } +template +void CummaxKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Op = Max; + auto op = Op(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, op, out); +} + template void LogcumsumexpKernel(const Context& dev_ctx, const DenseTensor& x, @@ -390,5 +412,19 @@ PD_REGISTER_KERNEL(cumsum, int, int64_t) {} -PD_REGISTER_KERNEL( - logcumsumexp, GPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {} +PD_REGISTER_KERNEL(cummax, + GPU, + ALL_LAYOUT, + phi::CummaxKernel, + float, + double, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(logcumsumexp, + GPU, + ALL_LAYOUT, + phi::LogcumsumexpKernel, + float, + double) {} diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 6d365622746e3f..e09ffb71365e29 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -3079,6 +3079,7 @@ def cumsum(x, axis=None, dtype=None, name=None): else: return _C_ops.cumsum(x, 'axis', axis, 'flatten', flatten) + # renziji: the static mode part is different, does it need to be modified ??? check_type(x, 'x', (Variable), 'cumsum') locals_var = locals().copy() kwargs = dict() @@ -3089,6 +3090,73 @@ def cumsum(x, axis=None, dtype=None, name=None): return _cum_sum_(**kwargs) +def cummax(x, axis=None, dtype=None, name=None): + """ + The cumulative sum of the elements along a given axis. + + **Note**: + The first element of the result is the same as the first element of the input. + + Args: + x (Tensor): The input tensor needed to be cumsumed. + axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array. + dtype (str, optional): The data type of the output tensor, can be float32, float64, int32, int64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the result of cumsum operator. + + Examples: + .. code-block:: python + + import paddle + + data = paddle.arange(12) + data = paddle.reshape(data, (3, 4)) + + y = paddle.cumsum(data) + # [ 0 1 3 6 10 15 21 28 36 45 55 66] + + y = paddle.cumsum(data, axis=0) + # [[ 0 1 2 3] + # [ 4 6 8 10] + # [12 15 18 21]] + + y = paddle.cumsum(data, axis=-1) + # [[ 0 1 3 6] + # [ 4 9 15 22] + # [ 8 17 27 38]] + + y = paddle.cumsum(data, dtype='float64') + print(y.dtype) + # paddle.float64 + """ + if axis is None: + flatten = True + else: + flatten = False + if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): + x = cast(x, dtype) + + if in_dygraph_mode(): + if axis is None: axis = -1 + return _C_ops.final_state_cummax(x, axis, flatten, False, False) + if _in_legacy_dygraph(): + if axis is None: + return _C_ops.cummax(x, 'flatten', flatten) + else: + return _C_ops.cummax(x, 'axis', axis, 'flatten', flatten) + + check_type(x, 'x', (Variable), 'cummax') + locals_var = locals().copy() + kwargs = dict() + for name, val in locals_var.items(): + if val is not None: + kwargs[name] = val + _cum_sum_ = generate_layer_fn('cummax') + return _cum_sum_(**kwargs) + + def logcumsumexp(x, axis=None, dtype=None, name=None): r""" The logarithm of the cumulative summation of the exponentiation of the elements along a given axis.