From cd76fb51d7b72d495c733e91d770032a328aeef5 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Thu, 26 Aug 2021 02:29:34 +0000 Subject: [PATCH 01/14] fix github name --- paddle/fluid/operators/math/algorithm.h | 12 +- paddle/fluid/operators/searchsorted_op.cc | 130 +++++++++++++++ paddle/fluid/operators/searchsorted_op.cu | 23 +++ paddle/fluid/operators/searchsorted_op.h | 156 ++++++++++++++++++ python/paddle/__init__.py | 1 + .../tests/unittests/test_searchsorted_op.py | 120 ++++++++++++++ python/paddle/tensor/__init__.py | 3 + python/paddle/tensor/search.py | 92 +++++++++++ 8 files changed, 531 insertions(+), 6 deletions(-) create mode 100644 paddle/fluid/operators/searchsorted_op.cc create mode 100644 paddle/fluid/operators/searchsorted_op.cu create mode 100644 paddle/fluid/operators/searchsorted_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_searchsorted_op.py diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h index 864cb94cec1e72..72eada950f11e1 100644 --- a/paddle/fluid/operators/math/algorithm.h +++ b/paddle/fluid/operators/math/algorithm.h @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ namespace operators { namespace math { template -HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { +HOSTDEVICE inline int64_t BinarySearch(const T *x, size_t num, const T &val) { int64_t beg = 0, end = num - 1; while (beg <= end) { auto mid = ((beg + end) >> 1); @@ -39,8 +39,8 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { return -1; } -template -HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) { +template +HOSTDEVICE inline size_t LowerBound(const T1 *x, size_t num, const T2 &val) { #if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group LowerBound // The following code is from // https://en.cppreference.com/w/cpp/algorithm/lower_bound @@ -62,8 +62,8 @@ HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) { #endif // @} End Group LowerBound } -template -HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) { +template +HOSTDEVICE inline size_t UpperBound(const T1 *x, size_t num, const T2 &val) { #if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group UpperBound // The following code is from // https://en.cppreference.com/w/cpp/algorithm/upper_bound diff --git a/paddle/fluid/operators/searchsorted_op.cc b/paddle/fluid/operators/searchsorted_op.cc new file mode 100644 index 00000000000000..f4fde436f89f62 --- /dev/null +++ b/paddle/fluid/operators/searchsorted_op.cc @@ -0,0 +1,130 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/searchsorted_op.h" + +#include +#include +#include +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { + +class SearchSortedOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + static bool SearchsortedDimsMatchedBeforeLastDim( + const framework::DDim& sequences_dims, + const framework::DDim& values_dims) { + if (sequences_dims.size() != values_dims.size()) { + return false; + } + const auto& sequences_dims_size = sequences_dims.size(); + for (int64_t dim = 0; dim < sequences_dims_size - 1; ++dim) { + if (sequences_dims[dim] != values_dims[dim]) { + return false; + } + } + return true; + } + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("SortedSequence"), "Input", "SortedSequence", + "searchsorted"); + OP_INOUT_CHECK(ctx->HasInput("Values"), "Input", "Values", "searchsorted"); + + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "searchsorted"); + + auto sequences_dims = ctx->GetInputDim("SortedSequence"); + auto values_dims = ctx->GetInputDim("Values"); + auto out_int32 = ctx->Attrs().Get("out_int32"); + + PADDLE_ENFORCE_EQ( + sequences_dims.size() == 1 || + SearchsortedDimsMatchedBeforeLastDim(sequences_dims, values_dims), + true, + platform::errors::Unavailable( + "The sorted_sequence tensor should be 1 dimension or the first N-1 " + "dimensions of sorted_sequence tensor and input values tensor must " + "match, but we got sorted_sequence tensor ( %s ), and input value " + "tensor ( %s )", + sequences_dims, values_dims)); + + if (out_int32) { + PADDLE_ENFORCE_GT( + sequences_dims[sequences_dims.size() - 1] < + std::numeric_limits::max(), + true, + platform::errors::Unavailable( + "the size of sorted_sequence last dimension should be less than " + "%d but we got %d", + std::numeric_limits::max(), + sequences_dims[sequences_dims.size() - 1])); + } + + ctx->SetOutputDim("Out", values_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "SortedSequence"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("SortedSequence", + "(Tensor), N-D or 1-D tensor, containing monotonically increasing " + "sequence on the innermost dimension."); + AddInput( + "Values", + "(Tensor), N-D tensor or a Scalar containing the search value(s)."); + AddOutput("Out", "(Tensor), The output tensor of searchsorted op."); + AddAttr("out_int32", + "the output tensor is int64_t type if False and int(32bit " + "normally) type if True.") + .SetDefault(false); + AddAttr( + "right", + "corresponding to lower bound if False and upper bound if True") + .SetDefault(false); + + AddComment(R"DOC( + Searchsorted Operator. + + This operator is used to find the indices of the value from the innermost dimension of sorted_sequence + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker); + +REGISTER_OP_CPU_KERNEL( + searchsorted, + ops::SearchSortedKernel, + ops::SearchSortedKernel, + ops::SearchSortedKernel, + ops::SearchSortedKernel); diff --git a/paddle/fluid/operators/searchsorted_op.cu b/paddle/fluid/operators/searchsorted_op.cu new file mode 100644 index 00000000000000..4633ab43efba12 --- /dev/null +++ b/paddle/fluid/operators/searchsorted_op.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/searchsorted_op.h" +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + searchsorted, ops::SearchSortedKernel, + ops::SearchSortedKernel, + ops::SearchSortedKernel, + ops::SearchSortedKernel); diff --git a/paddle/fluid/operators/searchsorted_op.h b/paddle/fluid/operators/searchsorted_op.h new file mode 100644 index 00000000000000..9215db90d6c7b7 --- /dev/null +++ b/paddle/fluid/operators/searchsorted_op.h @@ -0,0 +1,156 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/for_range.h" +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +class GpuAndCpuSearchSortedCompute { + public: + HOSTDEVICE GpuAndCpuSearchSortedCompute(const T1* sequence_data, + const T2* value_data, bool right, + bool is_1d_boundaries, + int64_t val_size, int64_t seq_size, + OutType* out_data) + : sequence_data_(sequence_data), + value_data_(value_data), + right_(right), + is_1d_boundaries_(is_1d_boundaries), + val_size_(val_size), + seq_size_(seq_size), + out_data_(out_data) {} + HOSTDEVICE void operator()(int64_t idx) { + const T2* value_ptr = value_data_ + idx; + const T1* sequence_ptr = is_1d_boundaries_ + ? sequence_data_ + : sequence_data_ + idx / val_size_ * seq_size_; + if (std::isnan(*value_ptr) || std::isinf(*value_ptr)) { + out_data_[idx] = 0; + } else { + if (right_) { + out_data_[idx] = static_cast( + math::UpperBound(sequence_ptr, seq_size_, *value_ptr)); + } else { + out_data_[idx] = static_cast( + math::LowerBound(sequence_ptr, seq_size_, *value_ptr)); + } + } + } + + private: + const T1* sequence_data_; + const T2* value_data_; + bool right_; + bool is_1d_boundaries_; + int64_t val_size_; + int64_t seq_size_; + OutType* out_data_; +}; + +template +class SearchSortedFunctor { + public: + SearchSortedFunctor(const framework::ExecutionContext& context, + const framework::Tensor* sorted_sequence, + const framework::Tensor* value, bool right, + OutType* out_data) + : context_(context), + sorted_sequence_(sorted_sequence), + value_(value), + right_(right), + out_data_(out_data) {} + + template + void apply() { + const T1* sequence_data = sorted_sequence_->data(); + const T2* value_data = value_->data(); + const framework::DDim& seq_dims = sorted_sequence_->dims(); + const framework::DDim& val_dims = value_->dims(); + + bool is_1d_boundaries = seq_dims.size() == 1; + int64_t val_size = val_dims[val_dims.size() - 1]; + int64_t seq_size = seq_dims[seq_dims.size() - 1]; + + auto& dev_ctx = context_.template device_context(); + platform::ForRange for_range(dev_ctx, value_->numel()); + GpuAndCpuSearchSortedCompute + gpu_and_cpu_search_sorted_compute(sequence_data, value_data, right_, + is_1d_boundaries, val_size, seq_size, + out_data_); + for_range(gpu_and_cpu_search_sorted_compute); + } + + private: + const framework::ExecutionContext& context_; + const framework::Tensor* sorted_sequence_; + const framework::Tensor* value_; + bool right_; + OutType* out_data_; +}; + +template +static void VisitDataType(framework::proto::VarType::Type type, + Visitor visitor) { + if (type == framework::proto::VarType::FP32) { + visitor.template apply(); + } else if (type == framework::proto::VarType::FP64) { + visitor.template apply(); + } else if (type == framework::proto::VarType::INT32) { + visitor.template apply(); + } else if (type == framework::proto::VarType::INT64) { + visitor.template apply(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The given values datatype of searchsorted operators must be float32, " + "float64, int32 or int64, but the recieved values datatype of " + "searchsorted operators is %s", + framework::DataTypeToString(type))); + } +} + +template +class SearchSortedKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* sorted_sequence = context.Input("SortedSequence"); + auto* value = context.Input("Values"); + bool out_int32 = context.Attr("out_int32"); + bool right = context.Attr("right"); + auto* out = context.Output("Out"); + + if (out_int32) { + int* out_data = out->mutable_data(context.GetPlace()); + SearchSortedFunctor functor( + context, sorted_sequence, value, right, out_data); + VisitDataType(value->type(), functor); + } else { + int64_t* out_data = out->mutable_data(context.GetPlace()); + SearchSortedFunctor functor( + context, sorted_sequence, value, right, out_data); + VisitDataType(value->type(), functor); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 1c38d519798666..ba3ee701885c4a 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -231,6 +231,7 @@ from .tensor.search import argmax # noqa: F401 from .tensor.search import argmin # noqa: F401 from .tensor.search import argsort # noqa: F401 +from .tensor.search import searchsorted # noqa: F401 from .tensor.search import masked_select # noqa: F401 from .tensor.search import topk # noqa: F401 from .tensor.search import where # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py new file mode 100644 index 00000000000000..402f9c7480e64e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import compiler, Program, program_guard + +paddle.enable_static() +from op_test import OpTest + + +class TestSearchSorted(OpTest): + def setUp(self): + + self.op_type = "searchsorted" + self.init_dtype() + self.init_test_case() + + self.inputs = { + 'SortedSequence': self.sorted_sequence, + 'Values': self.values + } + self.attrs = {"out_int32": False, "right": False} + self.side = "right" if self.attrs["right"] else "left" + self.outputs = { + 'Out': np.searchsorted( + self.sorted_sequence, self.values, side=self.side) + } + + def test_check_output(self): + self.check_output() + + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]) + self.values = np.array([[3, 6, 9], [3, 6, 9]]) + self.side = "left" + + def init_dtype(self): + self.dtype = np.float64 + + +class TestSearchSorted_float32(TestSearchSorted): + def init_dtype(self): + self.dtype = np.float32 + + +class TestSearchSorted_int32(TestSearchSorted): + def init_dtype(self): + self.dtype = np.int32 + + +class TestSearchSorted_int64(TestSearchSorted): + def init_dtype(self): + self.dtype = np.int64 + + +class TestSearchSortedAPI(unittest.TestCase): + def init_dtype(self): + self.dtype = np.int64 + + def setUp(self): + self.init_dtype() + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype(self.dtype) + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype(self.dtype) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_static_api(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + SortedSequence = paddle.fluid.data( + 'SortedSequence', dtype=self.dtype) + Values = paddle.fluid.data('Values', dtype=self.dtype) + out = np.searchsorted(SortedSequence, Values) + exe = paddle.static.Executor(place) + res = exe.run(feed={ + 'SortedSequence': self.sorted_sequence, + 'Values': self.values + }) + out_ref = np.searchsorted(self.sortedsequence, self.values) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + def run(place): + + paddle.disable_static(place) + SortedSequence = paddle.to_tensor(self.sorted_sequence) + Values = paddle.to_tensor(self.values) + out = paddle.searchsorted(SortedSequence, Values) + out_ref = np.searchsorted(self.sorted_sequence, self.values) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + paddle.enable_static() + + for place in self.place: + run(place) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index cc20e98006fec4..dfb17b291143ed 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -187,6 +187,7 @@ from .search import argmax # noqa: F401 from .search import argmin # noqa: F401 from .search import argsort # noqa: F401 +from .search import searchsorted # noqa: F401 from .search import topk # noqa: F401 from .search import where # noqa: F401 from .search import index_select # noqa: F401 @@ -251,6 +252,7 @@ 'round_', 'rsqrt', 'rsqrt_', + 'searchsorted' 'scale', 'scale_', 'sign', @@ -353,6 +355,7 @@ 'index_select', 'nonzero', 'sort', + 'searchsorted', 'index_sample', 'mean', 'std', diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 071a26151905a2..b289d382725a19 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -765,3 +765,95 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None): attrs=attrs) indices.stop_gradient = True return values, indices + + +def searchsorted(sorted_sequence, + values, + out_int32=False, + right=False, + name=None): + """ + This OP is used to find the indices from the innermost dimension of `sorted_sequence`. If the correspoding values in `values` + were inserted before the indices, the order of the corresponding *innermost* dimension within :attr:`sorted_sequence` would + be preserved.Return a new tensor with the same size as `values`. If `right` is False (default),then the left boundary of + `sorted_sequence` is closed. More formally, the returned index satisfies the following rules: + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence(Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the innermost dimension. The data type can be int32, int64, float32, float64. + values(Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). The data type can be int32, int64, float32, float64. + out_int32(bool, optional): indicate the output data type. The default value is False. + right(bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf). In other words, if False, gets the lower bound index for each value in `values` on the corresponding + innermost dimension of the `sorted_sequence`. If True, gets the upper bound index instead. The default value is False. + name(str, optional):The default value is None. Normally there is no need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + output (Tensor):return the indices from the innermost dimension of sorted_sequence. The output tensor is the same size as values. + + Examples: + + .. code-block:: python + import paddle + sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9,11], + [2, 4, 6, 8, 10, 12]],dtype = 'int32') + values = paddle.to_tensor([[3,6,9],[3,6,9]],dtype = 'int32') + out1 = paddle.searchsorted(sorted_sequence,values) + print(out1) + # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[1, 3, 4], + # [1, 2, 4]]) + + out2=paddle.searchsorted(sorted_sequence,values,right=True) + print(out2) + # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[2, 3, 5], + # [1, 3, 4]]) + + sorted_sequence_1d= paddle.to_tensor([1,3,5,7,9]) + out3=paddle.searchsorted(sorted_sequence_1d,values) + print(out3) + # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[1, 3, 4], + # [1, 3, 4]]) + """ + + if in_dygraph_mode(): + return _C_ops.searchsorted(sorted_sequence, values, "out_int32", + out_int32, "right", right) + + check_variable_and_dtype(sorted_sequence, 'SortedSequence', + ['float32', 'float64', 'int32', 'int64'], + 'paddle.searchsorted') + check_variable_and_dtype(values, 'Values', + ['float32', 'float64', 'int32', 'int64'], + 'paddle.searchsorted') + + helper = LayerHelper('searchsorted', **locals()) + out_type = 'int32' if out_int32 else 'int64' + out = helper.create_variable_for_type_inference(dtype=out_type) + helper.append_op( + type='searchsorted', + inputs={'SortedSequence': sorted_sequence, + "Values": values}, + outputs={'Out': out}, + attrs={"out_int32": out_int32, + "right": right}) + + return out From 35b2dc9f98386e8ae8fb14d311ab659fec8aed19 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Thu, 26 Aug 2021 07:41:30 +0000 Subject: [PATCH 02/14] fix CI error --- paddle/fluid/operators/searchsorted_op.cc | 29 ++++++++++--------- .../tests/unittests/test_searchsorted_op.py | 19 ++++++++---- python/paddle/tensor/__init__.py | 3 +- python/paddle/tensor/search.py | 6 ++-- 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/searchsorted_op.cc b/paddle/fluid/operators/searchsorted_op.cc index f4fde436f89f62..a9eaed927159f8 100644 --- a/paddle/fluid/operators/searchsorted_op.cc +++ b/paddle/fluid/operators/searchsorted_op.cc @@ -53,22 +53,23 @@ class SearchSortedOp : public framework::OperatorWithKernel { auto values_dims = ctx->GetInputDim("Values"); auto out_int32 = ctx->Attrs().Get("out_int32"); - PADDLE_ENFORCE_EQ( - sequences_dims.size() == 1 || - SearchsortedDimsMatchedBeforeLastDim(sequences_dims, values_dims), - true, - platform::errors::Unavailable( - "The sorted_sequence tensor should be 1 dimension or the first N-1 " - "dimensions of sorted_sequence tensor and input values tensor must " - "match, but we got sorted_sequence tensor ( %s ), and input value " - "tensor ( %s )", - sequences_dims, values_dims)); + if (sequences_dims.size() != 1) + PADDLE_ENFORCE_EQ( + SearchsortedDimsMatchedBeforeLastDim(sequences_dims, values_dims), + true, + platform::errors::Unavailable("The sorted_sequence tensor should be " + "1 dimension or the first N-1 " + "dimensions of sorted_sequence tensor " + "and input values tensor must " + "match, but we got sorted_sequence " + "tensor ( %s ), and input value " + "tensor ( %s )", + sequences_dims, values_dims)); if (out_int32) { - PADDLE_ENFORCE_GT( - sequences_dims[sequences_dims.size() - 1] < - std::numeric_limits::max(), - true, + PADDLE_ENFORCE_LT( + sequences_dims[sequences_dims.size() - 1], + std::numeric_limits::max(), platform::errors::Unavailable( "the size of sorted_sequence last dimension should be less than " "%d but we got %d", diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index 402f9c7480e64e..a13214a1348096 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -88,19 +88,26 @@ def test_static_api(self): def run(place): with paddle.static.program_guard(paddle.static.Program()): - SortedSequence = paddle.fluid.data( - 'SortedSequence', dtype=self.dtype) - Values = paddle.fluid.data('Values', dtype=self.dtype) - out = np.searchsorted(SortedSequence, Values) + sorted_sequence = paddle.static.data( + 'SortedSequence', + shape=self.sorted_sequence.shape, + dtype=self.dtype) + values = paddle.static.data( + 'Values', shape=self.values.shape, dtype=self.dtype) + out = paddle.searchsorted(sorted_sequence, values) exe = paddle.static.Executor(place) res = exe.run(feed={ 'SortedSequence': self.sorted_sequence, 'Values': self.values - }) - out_ref = np.searchsorted(self.sortedsequence, self.values) + }, + fetch_list=out) + out_ref = np.searchsorted(self.sorted_sequence, self.values) for r in res: self.assertEqual(np.allclose(out_ref, r), True) + for place in self.place: + run(place) + def test_dygraph_api(self): def run(place): diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index dfb17b291143ed..de9d87ae51fa76 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -252,7 +252,7 @@ 'round_', 'rsqrt', 'rsqrt_', - 'searchsorted' + 'searchsorted', 'scale', 'scale_', 'sign', @@ -355,7 +355,6 @@ 'index_select', 'nonzero', 'sort', - 'searchsorted', 'index_sample', 'mean', 'std', diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index b289d382725a19..43f0e98c765ad3 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -820,14 +820,14 @@ def searchsorted(sorted_sequence, # [[1, 3, 4], # [1, 2, 4]]) - out2=paddle.searchsorted(sorted_sequence,values,right=True) + out2 = paddle.searchsorted(sorted_sequence,values,right=True) print(out2) # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, # [[2, 3, 5], # [1, 3, 4]]) - sorted_sequence_1d= paddle.to_tensor([1,3,5,7,9]) - out3=paddle.searchsorted(sorted_sequence_1d,values) + sorted_sequence_1d = paddle.to_tensor([1,3,5,7,9]) + out3 = paddle.searchsorted(sorted_sequence_1d,values) print(out3) # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, # [[1, 3, 4], From f154c93ef96bf11ad7e672c5425cb01d43fd65fd Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Fri, 27 Aug 2021 03:59:27 +0000 Subject: [PATCH 03/14] fix review and CI error --- paddle/fluid/operators/math/algorithm.h | 2 +- paddle/fluid/operators/searchsorted_op.cc | 10 +- paddle/fluid/operators/searchsorted_op.h | 12 +- .../tests/unittests/test_searchsorted_op.py | 14 +-- python/paddle/tensor/__init__.py | 1 - python/paddle/tensor/search.py | 116 +++++++++--------- 6 files changed, 78 insertions(+), 77 deletions(-) diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h index 72eada950f11e1..346c693a22d852 100644 --- a/paddle/fluid/operators/math/algorithm.h +++ b/paddle/fluid/operators/math/algorithm.h @@ -25,7 +25,7 @@ namespace operators { namespace math { template -HOSTDEVICE inline int64_t BinarySearch(const T *x, size_t num, const T &val) { +HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { int64_t beg = 0, end = num - 1; while (beg <= end) { auto mid = ((beg + end) >> 1); diff --git a/paddle/fluid/operators/searchsorted_op.cc b/paddle/fluid/operators/searchsorted_op.cc index a9eaed927159f8..83b51c73fc569f 100644 --- a/paddle/fluid/operators/searchsorted_op.cc +++ b/paddle/fluid/operators/searchsorted_op.cc @@ -14,11 +14,6 @@ #include "paddle/fluid/operators/searchsorted_op.h" -#include -#include -#include -#include - #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -53,7 +48,7 @@ class SearchSortedOp : public framework::OperatorWithKernel { auto values_dims = ctx->GetInputDim("Values"); auto out_int32 = ctx->Attrs().Get("out_int32"); - if (sequences_dims.size() != 1) + if (sequences_dims.size() != 1) { PADDLE_ENFORCE_EQ( SearchsortedDimsMatchedBeforeLastDim(sequences_dims, values_dims), true, @@ -62,9 +57,10 @@ class SearchSortedOp : public framework::OperatorWithKernel { "dimensions of sorted_sequence tensor " "and input values tensor must " "match, but we got sorted_sequence " - "tensor ( %s ), and input value " + "tensor ( %s ), and input values " "tensor ( %s )", sequences_dims, values_dims)); + } if (out_int32) { PADDLE_ENFORCE_LT( diff --git a/paddle/fluid/operators/searchsorted_op.h b/paddle/fluid/operators/searchsorted_op.h index 9215db90d6c7b7..714d821ae94f44 100644 --- a/paddle/fluid/operators/searchsorted_op.h +++ b/paddle/fluid/operators/searchsorted_op.h @@ -27,6 +27,16 @@ using Tensor = framework::Tensor; template class GpuAndCpuSearchSortedCompute { public: + static HOSTDEVICE bool IsNan(float x) { return std::isnan(x); } + static HOSTDEVICE bool IsNan(double x) { return std::isnan(x); } + static HOSTDEVICE bool IsNan(int x) { return false; } + static HOSTDEVICE bool IsNan(int64_t x) { return false; } + + static HOSTDEVICE bool IsInf(float x) { return std::isinf(x); } + static HOSTDEVICE bool IsInf(double x) { return std::isinf(x); } + static HOSTDEVICE bool IsInf(int x) { return false; } + static HOSTDEVICE bool IsInf(int64_t x) { return false; } + HOSTDEVICE GpuAndCpuSearchSortedCompute(const T1* sequence_data, const T2* value_data, bool right, bool is_1d_boundaries, @@ -44,7 +54,7 @@ class GpuAndCpuSearchSortedCompute { const T1* sequence_ptr = is_1d_boundaries_ ? sequence_data_ : sequence_data_ + idx / val_size_ * seq_size_; - if (std::isnan(*value_ptr) || std::isinf(*value_ptr)) { + if (IsNan(*value_ptr) || IsInf(*value_ptr)) { out_data_[idx] = 0; } else { if (right_) { diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index a13214a1348096..a23c29e8098ec2 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -16,11 +16,8 @@ import numpy as np from op_test import OpTest import paddle -import paddle.nn as nn -import paddle.nn.functional as F -import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.fluid import compiler, Program, program_guard +from paddle.fluid import Program, program_guard paddle.enable_static() from op_test import OpTest @@ -102,8 +99,7 @@ def run(place): }, fetch_list=out) out_ref = np.searchsorted(self.sorted_sequence, self.values) - for r in res: - self.assertEqual(np.allclose(out_ref, r), True) + self.assertTrue(np.allclose(out_ref, res)) for place in self.place: run(place) @@ -112,9 +108,9 @@ def test_dygraph_api(self): def run(place): paddle.disable_static(place) - SortedSequence = paddle.to_tensor(self.sorted_sequence) - Values = paddle.to_tensor(self.values) - out = paddle.searchsorted(SortedSequence, Values) + sorted_sequence = paddle.to_tensor(self.sorted_sequence) + values = paddle.to_tensor(self.values) + out = paddle.searchsorted(sorted_sequence, values) out_ref = np.searchsorted(self.sorted_sequence, self.values) self.assertEqual(np.allclose(out_ref, out.numpy()), True) paddle.enable_static() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index de9d87ae51fa76..193a696dbb156a 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -252,7 +252,6 @@ 'round_', 'rsqrt', 'rsqrt_', - 'searchsorted', 'scale', 'scale_', 'sign', diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 43f0e98c765ad3..8158cb1a6bbb99 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -773,65 +773,65 @@ def searchsorted(sorted_sequence, right=False, name=None): """ - This OP is used to find the indices from the innermost dimension of `sorted_sequence`. If the correspoding values in `values` - were inserted before the indices, the order of the corresponding *innermost* dimension within :attr:`sorted_sequence` would - be preserved.Return a new tensor with the same size as `values`. If `right` is False (default),then the left boundary of - `sorted_sequence` is closed. More formally, the returned index satisfies the following rules: - * - :attr:`sorted_sequence` - - :attr:`right` - - *returned index satisfies* - * - 1-D - - False - - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` - * - 1-D - - True - - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` - * - N-D - - False - - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` - * - N-D - - True - - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + This OP is used to find the indices from the innermost dimension of `sorted_sequence`. If the correspoding values in `values` + were inserted before the indices, the order of the corresponding *innermost* dimension within :attr:`sorted_sequence` would + be preserved.Return a new tensor with the same size as `values`. If `right` is False (default),then the left boundary of + `sorted_sequence` is closed. More formally, the returned index satisfies the following rules: + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence(Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the innermost dimension. The data type can be int32, int64, float32, float64. + values(Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). The data type can be int32, int64, float32, float64. + out_int32(bool, optional): indicate the output data type. The default value is False. + right(bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf). In other words, if False, gets the lower bound index for each value in `values` on the corresponding + innermost dimension of the `sorted_sequence`. If True, gets the upper bound index instead. The default value is False. + name(str, optional):The default value is None. Normally there is no need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + output (Tensor):return the indices from the innermost dimension of sorted_sequence. The output tensor is the same size as values. + + Examples: - Args: - sorted_sequence(Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the innermost dimension. The data type can be int32, int64, float32, float64. - values(Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). The data type can be int32, int64, float32, float64. - out_int32(bool, optional): indicate the output data type. The default value is False. - right(bool, optional): if False, return the first suitable location that is found. If True, return the - last such index. If no suitable index found, return 0 for non-numerical value - (eg. nan, inf). In other words, if False, gets the lower bound index for each value in `values` on the corresponding - innermost dimension of the `sorted_sequence`. If True, gets the upper bound index instead. The default value is False. - name(str, optional):The default value is None. Normally there is no need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. - - Returns: - output (Tensor):return the indices from the innermost dimension of sorted_sequence. The output tensor is the same size as values. - - Examples: - - .. code-block:: python - import paddle - sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9,11], - [2, 4, 6, 8, 10, 12]],dtype = 'int32') - values = paddle.to_tensor([[3,6,9],[3,6,9]],dtype = 'int32') - out1 = paddle.searchsorted(sorted_sequence,values) - print(out1) - # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[1, 3, 4], - # [1, 2, 4]]) - - out2 = paddle.searchsorted(sorted_sequence,values,right=True) - print(out2) - # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[2, 3, 5], - # [1, 3, 4]]) - - sorted_sequence_1d = paddle.to_tensor([1,3,5,7,9]) - out3 = paddle.searchsorted(sorted_sequence_1d,values) - print(out3) - # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[1, 3, 4], - # [1, 3, 4]]) + .. code-block:: python + import paddle + sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9,11], + [2, 4, 6, 8, 10, 12]],dtype = 'int32') + values = paddle.to_tensor([[3,6,9],[3,6,9]],dtype = 'int32') + out1 = paddle.searchsorted(sorted_sequence,values) + print(out1) + # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[1, 3, 4], + # [1, 2, 4]]) + + out2 = paddle.searchsorted(sorted_sequence,values,right=True) + print(out2) + # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[2, 3, 5], + # [1, 3, 4]]) + + sorted_sequence_1d = paddle.to_tensor([1,3,5,7,9]) + out3 = paddle.searchsorted(sorted_sequence_1d,values) + print(out3) + # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[1, 3, 4], + # [1, 3, 4]]) """ if in_dygraph_mode(): From 7f44f54931b1d3c69c6f4391c74ad963c2cd6b40 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Mon, 30 Aug 2021 05:56:12 +0000 Subject: [PATCH 04/14] fix inf,nan error and modify unittest samples --- paddle/fluid/operators/searchsorted_op.h | 15 ++-- .../tests/unittests/test_searchsorted_op.py | 80 ++++++++++++++----- 2 files changed, 68 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/searchsorted_op.h b/paddle/fluid/operators/searchsorted_op.h index 714d821ae94f44..b503e1f92af947 100644 --- a/paddle/fluid/operators/searchsorted_op.h +++ b/paddle/fluid/operators/searchsorted_op.h @@ -14,12 +14,15 @@ #pragma once +#include + #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/for_range.h" + namespace paddle { namespace operators { using Tensor = framework::Tensor; @@ -27,13 +30,13 @@ using Tensor = framework::Tensor; template class GpuAndCpuSearchSortedCompute { public: - static HOSTDEVICE bool IsNan(float x) { return std::isnan(x); } - static HOSTDEVICE bool IsNan(double x) { return std::isnan(x); } + static HOSTDEVICE bool IsNan(float x) { return ::isnan(x); } + static HOSTDEVICE bool IsNan(double x) { return ::isnan(x); } static HOSTDEVICE bool IsNan(int x) { return false; } static HOSTDEVICE bool IsNan(int64_t x) { return false; } - static HOSTDEVICE bool IsInf(float x) { return std::isinf(x); } - static HOSTDEVICE bool IsInf(double x) { return std::isinf(x); } + static HOSTDEVICE bool IsInf(float x) { return ::isinf(x); } + static HOSTDEVICE bool IsInf(double x) { return ::isinf(x); } static HOSTDEVICE bool IsInf(int x) { return false; } static HOSTDEVICE bool IsInf(int64_t x) { return false; } @@ -54,8 +57,8 @@ class GpuAndCpuSearchSortedCompute { const T1* sequence_ptr = is_1d_boundaries_ ? sequence_data_ : sequence_data_ + idx / val_size_ * seq_size_; - if (IsNan(*value_ptr) || IsInf(*value_ptr)) { - out_data_[idx] = 0; + if (IsInf(*value_ptr) || IsNan(*value_ptr)) { + out_data_[idx] = seq_size_; } else { if (right_) { out_data_[idx] = static_cast( diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index a23c29e8098ec2..2d5d3033ab23bc 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -16,8 +16,8 @@ import numpy as np from op_test import OpTest import paddle +import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.fluid import Program, program_guard paddle.enable_static() from op_test import OpTest @@ -27,7 +27,6 @@ class TestSearchSorted(OpTest): def setUp(self): self.op_type = "searchsorted" - self.init_dtype() self.init_test_case() self.inputs = { @@ -35,7 +34,7 @@ def setUp(self): 'Values': self.values } self.attrs = {"out_int32": False, "right": False} - self.side = "right" if self.attrs["right"] else "left" + self.attrs["right"] = True if self.side == 'right' else False self.outputs = { 'Out': np.searchsorted( self.sorted_sequence, self.values, side=self.side) @@ -49,33 +48,47 @@ def init_test_case(self): self.values = np.array([[3, 6, 9], [3, 6, 9]]) self.side = "left" - def init_dtype(self): - self.dtype = np.float64 + +# sorted_sequence is one dimension +class TestSearchSortedOp1(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]) + self.values = np.array([[3, 6, 9], [3, 6, 9]]) + self.side = "right" -class TestSearchSorted_float32(TestSearchSorted): - def init_dtype(self): - self.dtype = np.float32 +class TestSearchSortedOp2(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]) + self.values = np.array([[3, 6, 9], [3, 6, 9]]) + self.side = "left" -class TestSearchSorted_int32(TestSearchSorted): - def init_dtype(self): - self.dtype = np.int32 +# if the element of values is nan +class TestSearchSortedOp3(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]) + self.values = np.array( + [[np.nan, np.nan, np.nan], [3, 6, 9]]).astype("float64") + self.side = "left" -class TestSearchSorted_int64(TestSearchSorted): - def init_dtype(self): - self.dtype = np.int64 +# if the element of values is inf +class TestSearchSortedOp4(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]) + self.values = np.array( + [[np.inf, np.inf, np.inf], [3, 6, 9]]).astype("float64") + self.side = "left" class TestSearchSortedAPI(unittest.TestCase): - def init_dtype(self): - self.dtype = np.int64 + def init_test_case(self): + self.sorted_sequence = np.array([2, 4, 6, 8, 10]).astype("float64") + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("float64") def setUp(self): - self.init_dtype() - self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype(self.dtype) - self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype(self.dtype) + self.init_test_case() self.place = [paddle.CPUPlace()] if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) @@ -88,9 +101,9 @@ def run(place): sorted_sequence = paddle.static.data( 'SortedSequence', shape=self.sorted_sequence.shape, - dtype=self.dtype) + dtype="float64") values = paddle.static.data( - 'Values', shape=self.values.shape, dtype=self.dtype) + 'Values', shape=self.values.shape, dtype="float64") out = paddle.searchsorted(sorted_sequence, values) exe = paddle.static.Executor(place) res = exe.run(feed={ @@ -119,5 +132,30 @@ def run(place): run(place) +class TestSearchSortedError(unittest.TestCase): + def test_error_api(self): + paddle.enable_static() + + def test_sortedsequence_values_dim_error(): + with paddle.static.program_guard(paddle.static.Program()): + sorted_sequence = paddle.static.data( + 'SortedSequence', shape=[2, 2, 3], dtype="float64") + values = paddle.static.data( + 'Values', shape=[2, 5], dtype="float64") + out = paddle.searchsorted(sorted_sequence, values) + + self.assertRaises(RuntimeError, test_sortedsequence_values_dim_error) + + def test_sortedsequence_values_type_error(): + with paddle.static.program_guard(paddle.static.Program()): + sorted_sequence = paddle.static.data( + 'SortedSequence', shape=[2, 3], dtype="int16") + values = paddle.static.data( + 'Values', shape=[2, 5], dtype="float64") + out = paddle.searchsorted(sorted_sequence, values) + + self.assertRaises(TypeError, test_sortedsequence_values_type_error) + + if __name__ == '__main__': unittest.main() From b28edf3028b42232ca7926c01f62f248f4a7bd4c Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Mon, 30 Aug 2021 10:13:37 +0000 Subject: [PATCH 05/14] add unittest samples --- .../tests/unittests/test_searchsorted_op.py | 56 ++++++++++++++----- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index 2d5d3033ab23bc..7a8fba582a53aa 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -44,39 +44,36 @@ def test_check_output(self): self.check_output() def init_test_case(self): - self.sorted_sequence = np.array([1, 3, 5, 7, 9]) - self.values = np.array([[3, 6, 9], [3, 6, 9]]) + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float32") + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("float32") self.side = "left" -# sorted_sequence is one dimension class TestSearchSortedOp1(TestSearchSorted): def init_test_case(self): - self.sorted_sequence = np.array([1, 3, 5, 7, 9]) - self.values = np.array([[3, 6, 9], [3, 6, 9]]) + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("int32") + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("int32") self.side = "right" class TestSearchSortedOp2(TestSearchSorted): def init_test_case(self): - self.sorted_sequence = np.array([1, 3, 5, 7, 9]) - self.values = np.array([[3, 6, 9], [3, 6, 9]]) + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("int64") + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("int64") self.side = "left" -# if the element of values is nan class TestSearchSortedOp3(TestSearchSorted): def init_test_case(self): - self.sorted_sequence = np.array([1, 3, 5, 7, 9]) + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64") self.values = np.array( [[np.nan, np.nan, np.nan], [3, 6, 9]]).astype("float64") self.side = "left" -# if the element of values is inf class TestSearchSortedOp4(TestSearchSorted): def init_test_case(self): - self.sorted_sequence = np.array([1, 3, 5, 7, 9]) + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64") self.values = np.array( [[np.inf, np.inf, np.inf], [3, 6, 9]]).astype("float64") self.side = "left" @@ -131,12 +128,19 @@ def run(place): for place in self.place: run(place) + def test_out_int32(self): + paddle.disable_static() + sorted_sequence = paddle.to_tensor(self.sorted_sequence) + values = paddle.to_tensor(self.values) + out = paddle.searchsorted(sorted_sequence, values, out_int32=True) + self.assertTrue(out.type, 'int32') + class TestSearchSortedError(unittest.TestCase): def test_error_api(self): paddle.enable_static() - def test_sortedsequence_values_dim_error(): + def test_searchsorted_dims_matched_before_lastdim_error1(): with paddle.static.program_guard(paddle.static.Program()): sorted_sequence = paddle.static.data( 'SortedSequence', shape=[2, 2, 3], dtype="float64") @@ -144,14 +148,38 @@ def test_sortedsequence_values_dim_error(): 'Values', shape=[2, 5], dtype="float64") out = paddle.searchsorted(sorted_sequence, values) - self.assertRaises(RuntimeError, test_sortedsequence_values_dim_error) + self.assertRaises(RuntimeError, + test_searchsorted_dims_matched_before_lastdim_error1) + + def test_searchsorted_dims_matched_before_lastdim_error2(): + with paddle.static.program_guard(paddle.static.Program()): + sorted_sequence = paddle.static.data( + 'SortedSequence', shape=[2, 2, 3], dtype="float64") + values = paddle.static.data( + 'Values', shape=[2, 3, 5], dtype="float64") + out = paddle.searchsorted(sorted_sequence, values) + + self.assertRaises(RuntimeError, + test_searchsorted_dims_matched_before_lastdim_error2) + + def test_searchsorted_sortedsequence_size_error(): + with paddle.static.program_guard(paddle.static.Program()): + sorted_sequence = paddle.static.data( + 'SortedSequence', shape=[2, 2, pow(2, 34)], dtype="float64") + values = paddle.static.data( + 'Values', shape=[2, 2, 5], dtype="float64") + out = paddle.searchsorted( + sorted_sequence, values, out_int32=True) + + self.assertRaises(RuntimeError, + test_searchsorted_sortedsequence_size_error) def test_sortedsequence_values_type_error(): with paddle.static.program_guard(paddle.static.Program()): sorted_sequence = paddle.static.data( 'SortedSequence', shape=[2, 3], dtype="int16") values = paddle.static.data( - 'Values', shape=[2, 5], dtype="float64") + 'Values', shape=[2, 5], dtype="int16") out = paddle.searchsorted(sorted_sequence, values) self.assertRaises(TypeError, test_sortedsequence_values_type_error) From cc92a717c39657067cc9003e9854256b6feb4539 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Mon, 30 Aug 2021 13:56:22 +0000 Subject: [PATCH 06/14] add unittest samples --- python/paddle/fluid/tests/unittests/test_searchsorted_op.py | 2 +- python/paddle/tensor/search.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index 7a8fba582a53aa..10837e44d553f2 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -120,7 +120,7 @@ def run(place): paddle.disable_static(place) sorted_sequence = paddle.to_tensor(self.sorted_sequence) values = paddle.to_tensor(self.values) - out = paddle.searchsorted(sorted_sequence, values) + out = paddle.searchsorted(sorted_sequence, values, right=True) out_ref = np.searchsorted(self.sorted_sequence, self.values) self.assertEqual(np.allclose(out_ref, out.numpy()), True) paddle.enable_static() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 8158cb1a6bbb99..a441a174b601f0 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -810,6 +810,7 @@ def searchsorted(sorted_sequence, Examples: .. code-block:: python + import paddle sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9,11], [2, 4, 6, 8, 10, 12]],dtype = 'int32') From 2f49f8b9869c2c6cc22f9b4016032f6e4ab7e963 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Tue, 31 Aug 2021 02:42:20 +0000 Subject: [PATCH 07/14] fix unittest error --- python/paddle/fluid/tests/unittests/test_searchsorted_op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index 10837e44d553f2..bd59860e740b21 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -121,7 +121,8 @@ def run(place): sorted_sequence = paddle.to_tensor(self.sorted_sequence) values = paddle.to_tensor(self.values) out = paddle.searchsorted(sorted_sequence, values, right=True) - out_ref = np.searchsorted(self.sorted_sequence, self.values) + out_ref = np.searchsorted( + self.sorted_sequence, self.values, side='right') self.assertEqual(np.allclose(out_ref, out.numpy()), True) paddle.enable_static() From 907e71fcbf99ecbed038f7457c356b7646c8e839 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Wed, 1 Sep 2021 09:48:08 +0000 Subject: [PATCH 08/14] test=document_fix --- .../tests/unittests/test_searchsorted_op.py | 2 +- python/paddle/tensor/search.py | 117 +++++++++--------- 2 files changed, 60 insertions(+), 59 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index bd59860e740b21..13434bdd4f4d8d 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -76,7 +76,7 @@ def init_test_case(self): self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64") self.values = np.array( [[np.inf, np.inf, np.inf], [3, 6, 9]]).astype("float64") - self.side = "left" + self.side = "right" class TestSearchSortedAPI(unittest.TestCase): diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index a441a174b601f0..3aeccbe06fc95a 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -773,66 +773,67 @@ def searchsorted(sorted_sequence, right=False, name=None): """ - This OP is used to find the indices from the innermost dimension of `sorted_sequence`. If the correspoding values in `values` - were inserted before the indices, the order of the corresponding *innermost* dimension within :attr:`sorted_sequence` would - be preserved.Return a new tensor with the same size as `values`. If `right` is False (default),then the left boundary of - `sorted_sequence` is closed. More formally, the returned index satisfies the following rules: - * - :attr:`sorted_sequence` - - :attr:`right` - - *returned index satisfies* - * - 1-D - - False - - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` - * - 1-D - - True - - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` - * - N-D - - False - - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` - * - N-D - - True - - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` - - Args: - sorted_sequence(Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the innermost dimension. The data type can be int32, int64, float32, float64. - values(Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). The data type can be int32, int64, float32, float64. - out_int32(bool, optional): indicate the output data type. The default value is False. - right(bool, optional): if False, return the first suitable location that is found. If True, return the - last such index. If no suitable index found, return 0 for non-numerical value - (eg. nan, inf). In other words, if False, gets the lower bound index for each value in `values` on the corresponding - innermost dimension of the `sorted_sequence`. If True, gets the upper bound index instead. The default value is False. - name(str, optional):The default value is None. Normally there is no need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. - - Returns: - output (Tensor):return the indices from the innermost dimension of sorted_sequence. The output tensor is the same size as values. - - Examples: + This OP is used to find the indices from the innermost dimension of `sorted_sequence`. If the correspoding values in `values` + were inserted before the indices, the order of the corresponding innermost dimension within :`sorted_sequence` would + be preserved. Return a new tensor with the same size as `values`. If `right` is False (default),then the left boundary of + `sorted_sequence` is closed. More formally, the returned index satisfies the following rules: + * - `sorted_sequence` + - `right` + - returned index satisfies + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence(Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the innermost dimension. The data type can be int32, int64, float32, float64. + values(Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). The data type can be int32, int64, float32, float64. + out_int32(bool, optional): indicate the output data type. The default value is False. + right(bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf). In other words, if False, gets the lower bound index for each value in `values` on the corresponding + innermost dimension of the `sorted_sequence`. If True, gets the upper bound index instead. The default value is False. + name(str, optional):The default value is None. Normally there is no need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + output (Tensor): return the indices from the innermost dimension of sorted_sequence. The output tensor is the same size as values. + + Examples: .. code-block:: python - - import paddle - sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9,11], - [2, 4, 6, 8, 10, 12]],dtype = 'int32') - values = paddle.to_tensor([[3,6,9],[3,6,9]],dtype = 'int32') - out1 = paddle.searchsorted(sorted_sequence,values) - print(out1) - # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[1, 3, 4], - # [1, 2, 4]]) - - out2 = paddle.searchsorted(sorted_sequence,values,right=True) - print(out2) - # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[2, 3, 5], - # [1, 3, 4]]) - - sorted_sequence_1d = paddle.to_tensor([1,3,5,7,9]) - out3 = paddle.searchsorted(sorted_sequence_1d,values) - print(out3) - # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[1, 3, 4], - # [1, 3, 4]]) + + import paddle + + sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9,11], + [2, 4, 6, 8, 10, 12]],dtype = 'int32') + values = paddle.to_tensor([[3,6,9],[3,6,9]],dtype = 'int32') + out1 = paddle.searchsorted(sorted_sequence,values) + print(out1) + # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[1, 3, 4], + # [1, 2, 4]]) + + out2 = paddle.searchsorted(sorted_sequence,values,right=True) + print(out2) + # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[2, 3, 5], + # [1, 3, 4]]) + + sorted_sequence_1d = paddle.to_tensor([1,3,5,7,9]) + out3 = paddle.searchsorted(sorted_sequence_1d,values) + print(out3) + # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[1, 3, 4], + # [1, 3, 4]]) """ if in_dygraph_mode(): From 32bdd22435a7c2d90388a3cfd4e0a1cb4ed6686b Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Wed, 1 Sep 2021 10:15:27 +0000 Subject: [PATCH 09/14] test=document_fix --- python/paddle/tensor/search.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 3aeccbe06fc95a..c3670c6a55628e 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -794,18 +794,18 @@ def searchsorted(sorted_sequence, - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` Args: - sorted_sequence(Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the innermost dimension. The data type can be int32, int64, float32, float64. - values(Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). The data type can be int32, int64, float32, float64. - out_int32(bool, optional): indicate the output data type. The default value is False. - right(bool, optional): if False, return the first suitable location that is found. If True, return the - last such index. If no suitable index found, return 0 for non-numerical value - (eg. nan, inf). In other words, if False, gets the lower bound index for each value in `values` on the corresponding - innermost dimension of the `sorted_sequence`. If True, gets the upper bound index instead. The default value is False. - name(str, optional):The default value is None. Normally there is no need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. + sorted_sequence(Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the innermost dimension. The data type can be int32, int64, float32, float64. + values(Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). The data type can be int32, int64, float32, float64. + out_int32(bool, optional): indicate the output data type. The default value is False. + right(bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf). In other words, if False, gets the lower bound index for each value in `values` on the corresponding + innermost dimension of the `sorted_sequence`. If True, gets the upper bound index instead. The default value is False. + name(str, optional):The default value is None. Normally there is no need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. Returns: - output (Tensor): return the indices from the innermost dimension of sorted_sequence. The output tensor is the same size as values. + output (Tensor): return the indices from the innermost dimension of sorted_sequence. The output tensor is the same size as values. Examples: From f21b24ab1b59cfac65882a8a8ac397a456855e17 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Wed, 1 Sep 2021 14:26:07 +0000 Subject: [PATCH 10/14] modify doc and add unittest samples --- paddle/fluid/operators/searchsorted_op.cc | 43 ++++++----- paddle/fluid/operators/searchsorted_op.h | 7 +- .../tests/unittests/test_searchsorted_op.py | 8 ++ python/paddle/tensor/search.py | 76 +++++++------------ 4 files changed, 62 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/operators/searchsorted_op.cc b/paddle/fluid/operators/searchsorted_op.cc index 83b51c73fc569f..8249b28c73247b 100644 --- a/paddle/fluid/operators/searchsorted_op.cc +++ b/paddle/fluid/operators/searchsorted_op.cc @@ -52,14 +52,16 @@ class SearchSortedOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( SearchsortedDimsMatchedBeforeLastDim(sequences_dims, values_dims), true, - platform::errors::Unavailable("The sorted_sequence tensor should be " - "1 dimension or the first N-1 " - "dimensions of sorted_sequence tensor " - "and input values tensor must " - "match, but we got sorted_sequence " - "tensor ( %s ), and input values " - "tensor ( %s )", - sequences_dims, values_dims)); + platform::errors::Unavailable( + "The dimensions of sorted_sequence tensor ( %s ) and " + "values tensor ( %s ) can not match. Because the input " + "sorted_sequence tensor must be " + "1 dimension or the first N-1 " + "dimensions of sorted_sequence tensor " + "and input values tensor must " + "match. Please input appropriate sorted_sequence and values " + "again!", + sequences_dims, values_dims)); } if (out_int32) { @@ -67,10 +69,12 @@ class SearchSortedOp : public framework::OperatorWithKernel { sequences_dims[sequences_dims.size() - 1], std::numeric_limits::max(), platform::errors::Unavailable( - "the size of sorted_sequence last dimension should be less than " - "%d but we got %d", - std::numeric_limits::max(), - sequences_dims[sequences_dims.size() - 1])); + "The size of sorted_sequence %d exceed the maximum limit " + "d%. Because the size of sorted_sequence should be less than the " + "output maximum value for int32 bit. Please set appropriate " + "sorted_sequence to meet this requirement!", + sequences_dims[sequences_dims.size() - 1], + std::numeric_limits::max())); } ctx->SetOutputDim("Out", values_dims); @@ -89,15 +93,14 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("SortedSequence", - "(Tensor), N-D or 1-D tensor, containing monotonically increasing " - "sequence on the innermost dimension."); - AddInput( - "Values", - "(Tensor), N-D tensor or a Scalar containing the search value(s)."); + "(Tensor), N-D or 1-D tensor, The value of the tensor" + "monotonically increases in the innermost dimension."); + AddInput("Values", + "(Tensor or Scalar), N-D tensor or a Scalar given values."); AddOutput("Out", "(Tensor), The output tensor of searchsorted op."); AddAttr("out_int32", - "the output tensor is int64_t type if False and int(32bit " - "normally) type if True.") + "the output tensor is int64 type if False and int32" + "normally type if True.") .SetDefault(false); AddAttr( "right", @@ -107,7 +110,7 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Searchsorted Operator. - This operator is used to find the indices of the value from the innermost dimension of sorted_sequence + This operator is used to find the index of the given value from the innermost dimension of sorted_sequence )DOC"); } diff --git a/paddle/fluid/operators/searchsorted_op.h b/paddle/fluid/operators/searchsorted_op.h index b503e1f92af947..7a31044401247f 100644 --- a/paddle/fluid/operators/searchsorted_op.h +++ b/paddle/fluid/operators/searchsorted_op.h @@ -134,9 +134,10 @@ static void VisitDataType(framework::proto::VarType::Type type, visitor.template apply(); } else { PADDLE_THROW(platform::errors::InvalidArgument( - "The given values datatype of searchsorted operators must be float32, " - "float64, int32 or int64, but the recieved values datatype of " - "searchsorted operators is %s", + "The recieved values data type %s can not meet " + "inputrequirements.Because the given values data type of searchsorted " + "operators must be float32, float64, int32 or int64. Please input " + "appropriate sorted_sequence again!", framework::DataTypeToString(type))); } } diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index 13434bdd4f4d8d..f595d06d5bce72 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -79,6 +79,14 @@ def init_test_case(self): self.side = "right" +class TestSearchSortedOp5(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64") + self.values = np.array([[np.inf, np.inf, np.inf], + [np.nan, np.nan, np.nan]]).astype("float64") + self.side = "right" + + class TestSearchSortedAPI(unittest.TestCase): def init_test_case(self): self.sorted_sequence = np.array([2, 4, 6, 8, 10]).astype("float64") diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index c3670c6a55628e..cf03b35c80a7f6 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -773,67 +773,45 @@ def searchsorted(sorted_sequence, right=False, name=None): """ - This OP is used to find the indices from the innermost dimension of `sorted_sequence`. If the correspoding values in `values` - were inserted before the indices, the order of the corresponding innermost dimension within :`sorted_sequence` would - be preserved. Return a new tensor with the same size as `values`. If `right` is False (default),then the left boundary of - `sorted_sequence` is closed. More formally, the returned index satisfies the following rules: - * - `sorted_sequence` - - `right` - - returned index satisfies - * - 1-D - - False - - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` - * - 1-D - - True - - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` - * - N-D - - False - - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` - * - N-D - - True - - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + This OP is used to find the index of the corresponding `sorted_sequence` in the innermost dimension based on the given `values`. Args: - sorted_sequence(Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the innermost dimension. The data type can be int32, int64, float32, float64. - values(Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). The data type can be int32, int64, float32, float64. - out_int32(bool, optional): indicate the output data type. The default value is False. - right(bool, optional): if False, return the first suitable location that is found. If True, return the - last such index. If no suitable index found, return 0 for non-numerical value - (eg. nan, inf). In other words, if False, gets the lower bound index for each value in `values` on the corresponding - innermost dimension of the `sorted_sequence`. If True, gets the upper bound index instead. The default value is False. - name(str, optional):The default value is None. Normally there is no need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. - + sorted_sequence(Tensor): An input N-D or 1-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension. + values(Tensor or Scalar): An input N-D tensor or a Scalar value with type int32, int64, float32, float64. + out_int32(bool, optional): Data type of the output tensor which can be int32, int64. The default value is False, and it indicates that the output data type is int64. + right(bool, optional): Find the upper and lower bounds of the sorted_sequence range in the innermost dimension based on the given `values`. If the sorted_sequence value is nan or inf, return the size of the innermost dimension + The default value is False and it shows the lower bounds. + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + Returns: - output (Tensor): return the indices from the innermost dimension of sorted_sequence. The output tensor is the same size as values. - + An N-D Tensor the same sizes of the `values`, return the tensor of `int32` if set :attr:`out_int32` is True, otherwise return the tensor of `int64`. + Examples: .. code-block:: python import paddle - sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9,11], - [2, 4, 6, 8, 10, 12]],dtype = 'int32') - values = paddle.to_tensor([[3,6,9],[3,6,9]],dtype = 'int32') - out1 = paddle.searchsorted(sorted_sequence,values) + sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9, 11], + [2, 4, 6, 8, 10, 12]], dtype='int32') + values = paddle.to_tensor([[3, 6, 9, 10], [3, 6, 9, 10]], dtype='int32') + out1 = paddle.searchsorted(sorted_sequence, values) print(out1) - # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[1, 3, 4], - # [1, 2, 4]]) - - out2 = paddle.searchsorted(sorted_sequence,values,right=True) + # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[1, 3, 4, 5], + # [1, 2, 4, 4]]) + out2 = paddle.searchsorted(sorted_sequence, values, right=True) print(out2) - # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[2, 3, 5], - # [1, 3, 4]]) - - sorted_sequence_1d = paddle.to_tensor([1,3,5,7,9]) - out3 = paddle.searchsorted(sorted_sequence_1d,values) + # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[2, 3, 5, 5], + # [1, 3, 4, 5]]) + sorted_sequence_1d = paddle.to_tensor([1, 3, 5, 7, 9, 11, 13]) + out3 = paddle.searchsorted(sorted_sequence_1d, values) print(out3) - # Tensor(shape=[2, 3], dtype=int64, place=CUDAPlace(0), stop_gradient=True, - # [[1, 3, 4], - # [1, 3, 4]]) + # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [[1, 3, 4, 5], + # [1, 3, 4, 5]]) """ if in_dygraph_mode(): From 753e3e17150dda7f90813d4bf012613101903ae6 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Thu, 2 Sep 2021 05:49:50 +0000 Subject: [PATCH 11/14] fix error newline in constant --- paddle/fluid/operators/searchsorted_op.cc | 4 ++-- paddle/fluid/operators/searchsorted_op.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/searchsorted_op.cc b/paddle/fluid/operators/searchsorted_op.cc index 8249b28c73247b..7d72c7de36cdb6 100644 --- a/paddle/fluid/operators/searchsorted_op.cc +++ b/paddle/fluid/operators/searchsorted_op.cc @@ -60,7 +60,7 @@ class SearchSortedOp : public framework::OperatorWithKernel { "dimensions of sorted_sequence tensor " "and input values tensor must " "match. Please input appropriate sorted_sequence and values " - "again!", + "again! ", sequences_dims, values_dims)); } @@ -72,7 +72,7 @@ class SearchSortedOp : public framework::OperatorWithKernel { "The size of sorted_sequence %d exceed the maximum limit " "d%. Because the size of sorted_sequence should be less than the " "output maximum value for int32 bit. Please set appropriate " - "sorted_sequence to meet this requirement!", + "sorted_sequence to meet this requirement! ", sequences_dims[sequences_dims.size() - 1], std::numeric_limits::max())); } diff --git a/paddle/fluid/operators/searchsorted_op.h b/paddle/fluid/operators/searchsorted_op.h index 7a31044401247f..6528259ef92de9 100644 --- a/paddle/fluid/operators/searchsorted_op.h +++ b/paddle/fluid/operators/searchsorted_op.h @@ -137,7 +137,7 @@ static void VisitDataType(framework::proto::VarType::Type type, "The recieved values data type %s can not meet " "inputrequirements.Because the given values data type of searchsorted " "operators must be float32, float64, int32 or int64. Please input " - "appropriate sorted_sequence again!", + "appropriate sorted_sequence again! ", framework::DataTypeToString(type))); } } From 0fcc53d75f73394677de98802ff212a072f9fba5 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Tue, 7 Sep 2021 13:41:32 +0000 Subject: [PATCH 12/14] modify doc after mentor review --- python/paddle/tensor/search.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index cf03b35c80a7f6..bf1e636220ed8b 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -776,13 +776,12 @@ def searchsorted(sorted_sequence, This OP is used to find the index of the corresponding `sorted_sequence` in the innermost dimension based on the given `values`. Args: - sorted_sequence(Tensor): An input N-D or 1-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension. - values(Tensor or Scalar): An input N-D tensor or a Scalar value with type int32, int64, float32, float64. + sorted_sequence(Tensor): An input N-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension. + values(Tensor): An input N-D tensor value with type int32, int64, float32, float64. out_int32(bool, optional): Data type of the output tensor which can be int32, int64. The default value is False, and it indicates that the output data type is int64. - right(bool, optional): Find the upper and lower bounds of the sorted_sequence range in the innermost dimension based on the given `values`. If the sorted_sequence value is nan or inf, return the size of the innermost dimension + right(bool, optional): Find the upper or lower bounds of the sorted_sequence range in the innermost dimension based on the given `values`. If the value of the sorted_sequence is nan or inf, return the size of the innermost dimension. The default value is False and it shows the lower bounds. - name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: An N-D Tensor the same sizes of the `values`, return the tensor of `int32` if set :attr:`out_int32` is True, otherwise return the tensor of `int64`. @@ -812,6 +811,7 @@ def searchsorted(sorted_sequence, # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True, # [[1, 3, 4, 5], # [1, 3, 4, 5]]) + """ if in_dygraph_mode(): From 33c599648d81bbcc8dd74c67f7d88fa4630732c8 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Thu, 9 Sep 2021 08:21:18 +0000 Subject: [PATCH 13/14] modify __all__ and doc --- paddle/fluid/operators/searchsorted_op.cc | 9 ++++----- python/paddle/__init__.py | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/searchsorted_op.cc b/paddle/fluid/operators/searchsorted_op.cc index 7d72c7de36cdb6..4f1703f0d12420 100644 --- a/paddle/fluid/operators/searchsorted_op.cc +++ b/paddle/fluid/operators/searchsorted_op.cc @@ -93,14 +93,13 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("SortedSequence", - "(Tensor), N-D or 1-D tensor, The value of the tensor" + "(Tensor), N-D tensor, The value of the tensor" "monotonically increases in the innermost dimension."); - AddInput("Values", - "(Tensor or Scalar), N-D tensor or a Scalar given values."); + AddInput("Values", "(Tensor), N-D tensor given values."); AddOutput("Out", "(Tensor), The output tensor of searchsorted op."); AddAttr("out_int32", - "the output tensor is int64 type if False and int32" - "normally type if True.") + "the output tensor is int64 type if False and On the" + "contrary for int32") .SetDefault(false); AddAttr( "right", diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ba3ee701885c4a..0fe16c19f80cd1 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -356,6 +356,7 @@ 'summary', 'flops', 'sort', + 'searchsorted', 'split', 'logical_and', 'full_like', From ed2173ddc561136704573977761e51cb0e4f1dcd Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Thu, 9 Sep 2021 11:52:26 +0000 Subject: [PATCH 14/14] modify doc --- paddle/fluid/operators/searchsorted_op.cc | 25 ++++++++++------------- paddle/fluid/operators/searchsorted_op.h | 8 ++++---- python/paddle/tensor/search.py | 4 ++-- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/searchsorted_op.cc b/paddle/fluid/operators/searchsorted_op.cc index 4f1703f0d12420..bbd5b9c4e7db91 100644 --- a/paddle/fluid/operators/searchsorted_op.cc +++ b/paddle/fluid/operators/searchsorted_op.cc @@ -53,14 +53,11 @@ class SearchSortedOp : public framework::OperatorWithKernel { SearchsortedDimsMatchedBeforeLastDim(sequences_dims, values_dims), true, platform::errors::Unavailable( - "The dimensions of sorted_sequence tensor ( %s ) and " - "values tensor ( %s ) can not match. Because the input " - "sorted_sequence tensor must be " - "1 dimension or the first N-1 " - "dimensions of sorted_sequence tensor " - "and input values tensor must " - "match. Please input appropriate sorted_sequence and values " - "again! ", + "The dimensions of sorted_sequence tensor ( %s ) and values " + "tensor ( %s ) can not match. Because the input sorted_sequence " + "tensor must be 1 dimension or the first N-1 dimensions of " + "sorted_sequence tensor and input values tensor must match. " + "Please input appropriate sorted_sequence and values again! ", sequences_dims, values_dims)); } @@ -69,8 +66,8 @@ class SearchSortedOp : public framework::OperatorWithKernel { sequences_dims[sequences_dims.size() - 1], std::numeric_limits::max(), platform::errors::Unavailable( - "The size of sorted_sequence %d exceed the maximum limit " - "d%. Because the size of sorted_sequence should be less than the " + "The size of sorted_sequence %d exceed the maximum limit d%. " + "Because the size of sorted_sequence should be less than the " "output maximum value for int32 bit. Please set appropriate " "sorted_sequence to meet this requirement! ", sequences_dims[sequences_dims.size() - 1], @@ -93,12 +90,12 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("SortedSequence", - "(Tensor), N-D tensor, The value of the tensor" + "(Tensor), N-D or 1-D tensor, The value of the tensor" "monotonically increases in the innermost dimension."); AddInput("Values", "(Tensor), N-D tensor given values."); AddOutput("Out", "(Tensor), The output tensor of searchsorted op."); AddAttr("out_int32", - "the output tensor is int64 type if False and On the" + "the output tensor is int64 type if False and on the" "contrary for int32") .SetDefault(false); AddAttr( @@ -109,8 +106,8 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Searchsorted Operator. - This operator is used to find the index of the given value from the innermost dimension of sorted_sequence - + This OP is used to find the index of the corresponding sorted_sequence in the innermost dimension based on the given values. + )DOC"); } }; diff --git a/paddle/fluid/operators/searchsorted_op.h b/paddle/fluid/operators/searchsorted_op.h index 6528259ef92de9..5ae0e79907bf99 100644 --- a/paddle/fluid/operators/searchsorted_op.h +++ b/paddle/fluid/operators/searchsorted_op.h @@ -134,10 +134,10 @@ static void VisitDataType(framework::proto::VarType::Type type, visitor.template apply(); } else { PADDLE_THROW(platform::errors::InvalidArgument( - "The recieved values data type %s can not meet " - "inputrequirements.Because the given values data type of searchsorted " - "operators must be float32, float64, int32 or int64. Please input " - "appropriate sorted_sequence again! ", + "The recieved values data type %s can not meet input requirements. " + "Because the given values data type of searchsorted operators must be " + "float32, float64, int32 or int64. Please input appropriate " + "sorted_sequence again! ", framework::DataTypeToString(type))); } } diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index bf1e636220ed8b..55c6f8ec67ca8c 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -776,7 +776,7 @@ def searchsorted(sorted_sequence, This OP is used to find the index of the corresponding `sorted_sequence` in the innermost dimension based on the given `values`. Args: - sorted_sequence(Tensor): An input N-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension. + sorted_sequence(Tensor): An input N-D or 1-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension. values(Tensor): An input N-D tensor value with type int32, int64, float32, float64. out_int32(bool, optional): Data type of the output tensor which can be int32, int64. The default value is False, and it indicates that the output data type is int64. right(bool, optional): Find the upper or lower bounds of the sorted_sequence range in the innermost dimension based on the given `values`. If the value of the sorted_sequence is nan or inf, return the size of the innermost dimension. @@ -784,7 +784,7 @@ def searchsorted(sorted_sequence, name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - An N-D Tensor the same sizes of the `values`, return the tensor of `int32` if set :attr:`out_int32` is True, otherwise return the tensor of `int64`. + Tensor(the same sizes of the `values`), return the tensor of int32 if set :attr:`out_int32` is True, otherwise return the tensor of int64. Examples: