From 6d880468d4211f4c08b932cff7ca4eff103d54a0 Mon Sep 17 00:00:00 2001 From: LuJiafeng Date: Tue, 13 Jul 2021 15:00:12 +0000 Subject: [PATCH 1/7] logical ops support int8, int16, int32, int64, float, double --- .../fluid/operators/controlflow/logical_op.cc | 23 +++---- .../fluid/operators/controlflow/logical_op.cu | 31 +++++---- .../fluid/operators/controlflow/logical_op.h | 43 +++++++++--- python/paddle/fluid/layers/nn.py | 31 +++++---- .../fluid/tests/unittests/test_logical_op.py | 67 +++++++++++-------- 5 files changed, 117 insertions(+), 78 deletions(-) diff --git a/paddle/fluid/operators/controlflow/logical_op.cc b/paddle/fluid/operators/controlflow/logical_op.cc index fb8cde70f5324f..285b17d4995dbc 100644 --- a/paddle/fluid/operators/controlflow/logical_op.cc +++ b/paddle/fluid/operators/controlflow/logical_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -26,15 +23,16 @@ class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker { void Make() override { OpComment comment; AddInput("X", string::Sprintf("Left hand operand of %s operator. Must be " - "a Variable of type bool.", + "a Variable of type bool, int8, " + "int16, int32, int64, float32, or float64.", comment.type)); AddInput("Y", string::Sprintf("Right hand operand of %s operator. Must be " - "a Variable of type bool.", + "a Variable of type bool, int8, " + "int16, int32, int64, float32, or float64.", comment.type)); AddOutput("Out", string::Sprintf("n-dim bool Variable")); AddComment(string::Sprintf(R"DOC(%s Operator - -It operates element-wise on X and Y, and returns the Out. X, Y and Out are N-dim boolean LoDTensor or Tensor. +It operates element-wise on X and Y, and returns Out. X, Y and Out are N-dim LoDTensor or Tensor. Each element of Out is calculated by %s )DOC", comment.type, comment.equation)); @@ -46,13 +44,14 @@ class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { OpComment comment; - AddInput("X", string::Sprintf("Operand of %s operator. Must be " - "a LoDTensor or Tensor of type bool.", - comment.type)); + AddInput("X", + string::Sprintf("Operand of %s operator. Must be " + "a LoDTensor or Tensor of type bool, " + "int8, int16, int32, int64, float32, or float64.", + comment.type)); AddOutput("Out", string::Sprintf("n-dim bool LoDTensor or Tensor.")); AddComment(string::Sprintf(R"DOC(%s Operator - -It operates element-wise on X, and returns the Out. X and Out are N-dim boolean LoDTensor or Tensor. +It operates element-wise on X, and returns Out. X and Out are N-dim LoDTensor or Tensor. Each element of Out is calculated by %s )DOC", comment.type, comment.equation)); diff --git a/paddle/fluid/operators/controlflow/logical_op.cu b/paddle/fluid/operators/controlflow/logical_op.cu index 6cbcd516e08264..301b4c4149fad3 100644 --- a/paddle/fluid/operators/controlflow/logical_op.cu +++ b/paddle/fluid/operators/controlflow/logical_op.cu @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. 
- Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -21,13 +18,13 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { -#define LOGICAL_BINARY_FUNCTOR(func_name, op) \ - template <typename T> \ - struct func_name { \ - using ELEMENT_TYPE = T; \ - HOSTDEVICE bool operator()(const T* args) const { \ - return args[0] op args[1]; \ - } \ +#define LOGICAL_BINARY_FUNCTOR(func_name, op) \ + template <typename T> \ + struct func_name { \ + using ELEMENT_TYPE = T; \ + HOSTDEVICE bool operator()(const T* args) const { \ + return static_cast<bool>(args[0]) op static_cast<bool>(args[1]); \ + } \ }; LOGICAL_BINARY_FUNCTOR(CudaOrFunctor, ||) @@ -68,10 +65,16 @@ class BinaryLogicalOpKernel } // namespace operators } // namespace paddle -#define REGISTER_LOGICAL_CUDA_KERNEL(op_name, func) \ - REGISTER_OP_CUDA_KERNEL( \ - op_name, \ - ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, func<bool>>); +#define REGISTER_LOGICAL_CUDA_KERNEL(op_name, func) \ + REGISTER_OP_CUDA_KERNEL( \ + op_name, \ + ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, func<bool>>, \ + ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, func<int8_t>>, \ + ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, func<int16_t>>, \ + ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, func<int>>, \ + ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, func<int64_t>>, \ + ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, func<float>>, \ + ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, func<double>>); REGISTER_LOGICAL_CUDA_KERNEL(logical_or, CudaOrFunctor) REGISTER_LOGICAL_CUDA_KERNEL(logical_and, CudaAndFunctor) diff --git a/paddle/fluid/operators/controlflow/logical_op.h b/paddle/fluid/operators/controlflow/logical_op.h index 2c39201a426a25..92fe0a10cb907c 100644 --- a/paddle/fluid/operators/controlflow/logical_op.h +++ b/paddle/fluid/operators/controlflow/logical_op.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -82,12 +79,36 @@ class UnaryLogicalOpKernel } // namespace operators } // namespace paddle -#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \ - REGISTER_OP_##dev##_KERNEL( \ - op_type, ::paddle::operators::BinaryLogicalOpKernel< \ - ::paddle::platform::dev##DeviceContext, functor<bool>>); +#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \ + REGISTER_OP_##dev##_KERNEL( \ + op_type, ::paddle::operators::BinaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<bool>>, \ + ::paddle::operators::BinaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<int8_t>>, \ + ::paddle::operators::BinaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<int16_t>>, \ + ::paddle::operators::BinaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<int>>, \ + ::paddle::operators::BinaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<int64_t>>, \ + ::paddle::operators::BinaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<float>>, \ + ::paddle::operators::BinaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<double>>); -#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \ - REGISTER_OP_##dev##_KERNEL( \ - op_type, ::paddle::operators::UnaryLogicalOpKernel< \ - ::paddle::platform::dev##DeviceContext, functor<bool>>); +#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \ + REGISTER_OP_##dev##_KERNEL( \ + op_type, ::paddle::operators::UnaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<bool>>, \ + ::paddle::operators::UnaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<int8_t>>, \ + ::paddle::operators::UnaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<int16_t>>, \ + ::paddle::operators::UnaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<int>>, \ + ::paddle::operators::UnaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<int64_t>>, \ + ::paddle::operators::UnaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<float>>, \ + ::paddle::operators::UnaryLogicalOpKernel< \ + ::paddle::platform::dev##DeviceContext, functor<double>>); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c1c97c3f7742b7..a5cab4b9e882cc 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12152,17 +12152,22 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True): return op(x, y) else: return op(x) - - check_variable_and_dtype(x, "x", ["bool"], op_name) + check_variable_and_dtype(x, "x", [ + "bool", "int8", "int16", "int32", "int64", "float32", "float64" + ], op_name) if y is not None: - check_variable_and_dtype(y, "y", ["bool"], op_name) + check_variable_and_dtype(y, "y", [ + "bool", "int8", "int16", "int32", "int64", "float32", "float64" + ], op_name) if out is not None: check_type(out, "out", Variable, op_name) helper = LayerHelper(op_name, **locals()) - if binary_op: - assert x.dtype == y.dtype + if binary_op and x.dtype != y.dtype: + raise ValueError( + "(InvalidArgument) The DataType of %s Op's Variable must be consistent, but received %s and %s." + % (op_name, x.dtype, y.dtype)) if out is None: out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -12180,7 +12185,7 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True): def logical_and(x, y, out=None, name=None): r""" - ``logical_and`` operator computes element-wise logical AND on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Tensor``. 
+ ``logical_and`` operator computes element-wise logical AND on ``x`` and ``y``, and returns ``out``. ``out`` is N-dim boolean ``Tensor``. Each element of ``out`` is calculated by .. math:: @@ -12191,8 +12196,8 @@ def logical_and(x, y, out=None, name=None): ``paddle.logical_and`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. Args: - x (Tensor): the input tensor, it's data type should be bool. - y (Tensor): the input tensor, it's data type should be bool. + x (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64. + y (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64. out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -12255,7 +12260,7 @@ def logical_or(x, y, out=None, name=None): def logical_xor(x, y, out=None, name=None): r""" - ``logical_xor`` operator computes element-wise logical XOR on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Tensor``. + ``logical_xor`` operator computes element-wise logical XOR on ``x`` and ``y``, and returns ``out``. ``out`` is N-dim boolean ``Tensor``. Each element of ``out`` is calculated by .. math:: @@ -12266,8 +12271,8 @@ def logical_xor(x, y, out=None, name=None): ``paddle.logical_xor`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. Args: - x (Tensor): the input tensor, it's data type should be bool. - y (Tensor): the input tensor, it's data type should be bool. + x (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64. + y (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64. out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -12295,7 +12300,7 @@ def logical_xor(x, y, out=None, name=None): def logical_not(x, out=None, name=None): """ - ``logical_not`` operator computes element-wise logical NOT on ``x``, and returns ``out``. ``x`` and ``out`` are N-dim boolean ``Variable``. + ``logical_not`` operator computes element-wise logical NOT on ``x``, and returns ``out``. ``out`` is N-dim boolean ``Variable``. Each element of ``out`` is calculated by .. math:: out = !x Args: - x(Tensor): Operand of logical_not operator. Must be a Tensor of type bool. + x(Tensor): Operand of logical_not operator. Must be a Tensor of type bool, int8, int16, int32, int64, float32, or float64. out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor` will be created to save the output. name(str|None): The default value is None. Normally there is no need for users to set this property. 
For more information, please refer to :ref:`api_guide_Name`. diff --git a/python/paddle/fluid/tests/unittests/test_logical_op.py b/python/paddle/fluid/tests/unittests/test_logical_op.py index c8bb8c5b73f768..e77526bdb16bc9 100755 --- a/python/paddle/fluid/tests/unittests/test_logical_op.py +++ b/python/paddle/fluid/tests/unittests/test_logical_op.py @@ -21,6 +21,10 @@ import paddle.fluid as fluid from paddle.static import Program, program_guard +SUPPORTED_DTYPES = [ + bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64 +] + TEST_META_OP_DATA = [{ 'op_str': 'logical_and', 'binary_op': True @@ -111,13 +115,13 @@ def run_static(x_np, y_np, op_str, use_gpu=False, binary_op=True): place = paddle.CUDAPlace(0) exe = fluid.Executor(place) with fluid.program_guard(main_program, startup_program): - x = paddle.static.data(name='x', shape=x_np.shape, dtype='bool') + x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype) op = getattr(paddle, op_str) feed_list = {'x': x_np} if not binary_op: res = op(x) else: - y = paddle.static.data(name='y', shape=y_np.shape, dtype='bool') + y = paddle.static.data(name='y', shape=y_np.shape, dtype=y_np.dtype) feed_list['y'] = y_np res = op(x, y) exe.run(startup_program) @@ -131,17 +135,20 @@ def run_dygraph(x_np, y_np, op_str, use_gpu=False, binary_op=True): place = paddle.CUDAPlace(0) paddle.disable_static(place) op = getattr(paddle, op_str) - x = paddle.to_tensor(x_np) + x = paddle.to_tensor(x_np, dtype=x_np.dtype) if not binary_op: dygraph_result = op(x) else: - y = paddle.to_tensor(y_np) + y = paddle.to_tensor(y_np, dtype=y_np.dtype) dygraph_result = op(x, y) return dygraph_result -def np_data_generator(np_shape, *args, **kwargs): - return np.random.choice(a=[True, False], size=np_shape).astype(bool) +def np_data_generator(np_shape, dtype, *args, **kwargs): + if dtype == bool: + return np.random.choice(a=[True, False], size=np_shape).astype(bool) + else: + return np.random.randn(*np_shape).astype(dtype) def test(unit_test, use_gpu=False, test_error=False): @@ -153,40 +160,46 @@ def test(unit_test, use_gpu=False, test_error=False): if test_error: META_DATA = dict(TEST_META_WRONG_SHAPE_DATA) for shape_data in META_DATA.values(): - meta_data['x_np'] = np_data_generator(shape_data['x_shape']) - meta_data['y_np'] = np_data_generator(shape_data['y_shape']) - if meta_data['binary_op'] and test_error: - # catch C++ Exception - unit_test.assertRaises(BaseException, run_static, **meta_data) - unit_test.assertRaises(BaseException, run_dygraph, **meta_data) - continue - static_result = run_static(**meta_data) - dygraph_result = run_dygraph(**meta_data) - if meta_data['binary_op']: - np_result = np_op(meta_data['x_np'], meta_data['y_np']) - else: - np_result = np_op(meta_data['x_np']) - unit_test.assertTrue((static_result == np_result).all()) - unit_test.assertTrue((dygraph_result.numpy() == np_result).all()) + for data_type in SUPPORTED_DTYPES: + meta_data['x_np'] = np_data_generator( + shape_data['x_shape'], dtype=data_type) + meta_data['y_np'] = np_data_generator( + shape_data['y_shape'], dtype=data_type) + if meta_data['binary_op'] and test_error: + # catch C++ Exception + unit_test.assertRaises(BaseException, run_static, + **meta_data) + unit_test.assertRaises(BaseException, run_dygraph, + **meta_data) + continue + static_result = run_static(**meta_data) + dygraph_result = run_dygraph(**meta_data) + if meta_data['binary_op']: + np_result = np_op(meta_data['x_np'], meta_data['y_np']) + else: + np_result = np_op(meta_data['x_np']) + 
unit_test.assertTrue((static_result == np_result).all()) + unit_test.assertTrue((dygraph_result.numpy() == np_result).all( + )) def test_type_error(unit_test, use_gpu, type_str_map): def check_type(op_str, x, y, binary_op): op = getattr(paddle, op_str) - error_type = TypeError + error_type = ValueError if isinstance(x, np.ndarray): x = paddle.to_tensor(x) y = paddle.to_tensor(y) error_type = BaseException if binary_op: - if type_str_map['x'] != 'bool' or type_str_map['y'] != 'bool': + if type_str_map['x'] != type_str_map['y']: unit_test.assertRaises(error_type, op, x=x, y=y) if not fluid.in_dygraph_mode(): + error_type = TypeError unit_test.assertRaises(error_type, op, x=x, y=y, out=1) else: - if type_str_map['x'] != 'bool': - unit_test.assertRaises(error_type, op, x=x) if not fluid.in_dygraph_mode(): + error_type = TypeError unit_test.assertRaises(error_type, op, x=x, out=1) place = paddle.CPUPlace() @@ -213,12 +226,10 @@ def check_type(op_str, x, y, binary_op): def type_map_factory(): - x_type_list = ['float32', 'float64', 'int32', 'int64', 'bool'] - y_type_list = ['float32', 'float64', 'int32', 'int64', 'bool'] return [{ 'x': x_type, 'y': y_type - } for x_type in x_type_list for y_type in y_type_list] + } for x_type in SUPPORTED_DTYPES for y_type in SUPPORTED_DTYPES] class TestCPU(unittest.TestCase): From fda2abf0394e3d6659197dabb78d70ce46af495e Mon Sep 17 00:00:00 2001 From: LuJiafeng Date: Wed, 14 Jul 2021 11:46:26 +0000 Subject: [PATCH 2/7] update docs of logical ops --- python/paddle/fluid/layers/nn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a5cab4b9e882cc..971d7daf43c74c 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12221,7 +12221,7 @@ def logical_and(x, y, out=None, name=None): def logical_or(x, y, out=None, name=None): """ - ``logical_or`` operator computes element-wise logical OR on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Tensor``. + ``logical_or`` operator computes element-wise logical OR on ``x`` and ``y``, and returns ``out``. ``out`` is N-dim boolean ``Tensor``. Each element of ``out`` is calculated by .. math:: @@ -12232,8 +12232,8 @@ def logical_or(x, y, out=None, name=None): ``paddle.logical_or`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. Args: - x (Tensor): the input tensor, it's data type should be bool. - y (Tensor): the input tensor, it's data type should be bool. + x (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64. + y (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64. out(Tensor): The ``Variable`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. 
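
Note: the following usage sketch is an editorial illustration, not part of the patch series. It shows the behavior the first two patches aim to enable, assuming a Paddle build with them applied; the concrete values are only for demonstration.

    import numpy as np
    import paddle

    # Non-bool operands are accepted once the dtype list is extended; each
    # element is converted to bool (zero -> False, non-zero -> True), matching
    # the static_cast<bool> conversion in the CUDA functor from patch 1.
    x = paddle.to_tensor(np.array([0, 1, 2, -3], dtype=np.int32))
    y = paddle.to_tensor(np.array([0, 0, 5, 7], dtype=np.int32))
    out = paddle.logical_and(x, y)
    print(out.numpy())  # [False False  True  True], output dtype is bool

    # Mixing operand dtypes is rejected by the consistency check added to
    # _logical_op: a ValueError in static graph mode, a C++ error in dygraph.
    try:
        paddle.logical_and(x, y.astype('float32'))
    except BaseException as error:  # mirrors the broad catch in the tests
        print(type(error).__name__)
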
From ac8b193108cec0991bffcf81a6a09750f36e683b Mon Sep 17 00:00:00 2001 From: LuJiafeng Date: Wed, 14 Jul 2021 12:35:29 +0000 Subject: [PATCH 3/7] fix npu and xpu logical ops --- paddle/fluid/operators/controlflow/logical_op_npu.cc | 8 +++++++- paddle/fluid/operators/controlflow/logicaland_op_xpu.cc | 8 +++++++- paddle/fluid/operators/controlflow/logicalnot_op_xpu.cc | 8 +++++++- paddle/fluid/operators/controlflow/logicalor_op_xpu.cc | 8 +++++++- 4 files changed, 28 insertions(+), 4 deletions(-) mode change 100755 => 100644 paddle/fluid/operators/controlflow/logicalnot_op_xpu.cc diff --git a/paddle/fluid/operators/controlflow/logical_op_npu.cc b/paddle/fluid/operators/controlflow/logical_op_npu.cc index b9807bfa53e1e1..cf9b6f7ecec18f 100644 --- a/paddle/fluid/operators/controlflow/logical_op_npu.cc +++ b/paddle/fluid/operators/controlflow/logical_op_npu.cc @@ -52,6 +52,12 @@ namespace ops = paddle::operators; REGISTER_OP_NPU_KERNEL( logical_not, - ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, bool>); + ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, bool>, + ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, int8_t>, + ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, int16_t>, + ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, int>, + ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, int64_t>, + ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, float>, + ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, double>); #endif diff --git a/paddle/fluid/operators/controlflow/logicaland_op_xpu.cc b/paddle/fluid/operators/controlflow/logicaland_op_xpu.cc index 08927e66f25064..6248b6e0b06378 100644 --- a/paddle/fluid/operators/controlflow/logicaland_op_xpu.cc +++ b/paddle/fluid/operators/controlflow/logicaland_op_xpu.cc @@ -17,5 +17,11 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_XPU_KERNEL( logical_and, - ops::BinaryLogicalOpXPUKernel); + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel); #endif diff --git a/paddle/fluid/operators/controlflow/logicalnot_op_xpu.cc b/paddle/fluid/operators/controlflow/logicalnot_op_xpu.cc old mode 100755 new mode 100644 index a8cef52ace2c60..be857db8aa9669 --- a/paddle/fluid/operators/controlflow/logicalnot_op_xpu.cc +++ b/paddle/fluid/operators/controlflow/logicalnot_op_xpu.cc @@ -15,5 +15,11 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU #include "paddle/fluid/operators/controlflow/logical_op_xpu.h" namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL(logicalnot, ops::UnaryLogicalOpXPUKernel); +REGISTER_OP_XPU_KERNEL(logicalnot, ops::UnaryLogicalOpXPUKernel, + ops::UnaryLogicalOpXPUKernel, + ops::UnaryLogicalOpXPUKernel, + ops::UnaryLogicalOpXPUKernel, + ops::UnaryLogicalOpXPUKernel, + ops::UnaryLogicalOpXPUKernel, + ops::UnaryLogicalOpXPUKernel); #endif diff --git a/paddle/fluid/operators/controlflow/logicalor_op_xpu.cc b/paddle/fluid/operators/controlflow/logicalor_op_xpu.cc index e99c2f1a181040..126596841a29f8 100644 --- a/paddle/fluid/operators/controlflow/logicalor_op_xpu.cc +++ b/paddle/fluid/operators/controlflow/logicalor_op_xpu.cc @@ -18,5 +18,11 @@ limitations under the License. 
*/ namespace ops = paddle::operators; REGISTER_OP_XPU_KERNEL( logical_or, - ops::BinaryLogicalOpXPUKernel); + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel, + ops::BinaryLogicalOpXPUKernel); #endif From c130cf2ab2aec359d38dcc4b9c2c54a89072c120 Mon Sep 17 00:00:00 2001 From: LuJiafeng Date: Thu, 15 Jul 2021 03:22:19 +0000 Subject: [PATCH 4/7] fix npu and xpu logical ops --- .../operators/controlflow/logical_op_npu.cc | 86 ++++++++++++++----- 1 file changed, 65 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/controlflow/logical_op_npu.cc b/paddle/fluid/operators/controlflow/logical_op_npu.cc index cf9b6f7ecec18f..babdb2257ee3ca 100644 --- a/paddle/fluid/operators/controlflow/logical_op_npu.cc +++ b/paddle/fluid/operators/controlflow/logical_op_npu.cc @@ -1,21 +1,14 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_ASCEND_CL -#include -#include - #include "paddle/fluid/operators/controlflow/logical_op.h" #include "paddle/fluid/operators/npu_op_runner.h" @@ -29,12 +22,9 @@ class LogicalNotNPUKernel : public framework::OpKernel<T> { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* x = ctx.Input<Tensor>("X"); - auto* out = ctx.Output<Tensor>("Out"); - auto place = ctx.GetPlace(); - - out->mutable_data<bool>(place); + out->mutable_data<bool>(ctx.GetPlace()); auto stream = ctx.template device_context<paddle::platform::NPUDeviceContext>() .stream(); @@ -45,19 +35,73 @@ } }; +template <typename DeviceContext, typename T> +class LogicalOrNPUKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input<Tensor>("X"); + auto* y = ctx.Input<Tensor>("Y"); + auto* out = ctx.Output<Tensor>("Out"); + + out->mutable_data<bool>(ctx.GetPlace()); + + auto stream = + ctx.template device_context<paddle::platform::NPUDeviceContext>() + .stream(); + + const auto& runner = NpuOpRunner("LogicalOr", {*x, *y}, {*out}, {}); + runner.Run(stream); + } +}; + +template <typename DeviceContext, typename T> +class LogicalAndNPUKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input<Tensor>("X"); + auto* y = ctx.Input<Tensor>("Y"); + auto* out = ctx.Output<Tensor>("Out"); + + out->mutable_data<bool>(ctx.GetPlace()); + + auto stream = + ctx.template device_context<paddle::platform::NPUDeviceContext>() + .stream(); + + const auto& runner = NpuOpRunner("LogicalAnd", {*x, *y}, {*out}, {}); + runner.Run(stream); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_NPU_KERNEL( - logical_not, - ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, bool>, - ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, int8_t>, - ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, int16_t>, - ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, int>, - ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, int64_t>, - ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, float>, - ops::LogicalNotNPUKernel<paddle::platform::NPUDeviceContext, double>); - -#endif + logical_not, ops::LogicalNotNPUKernel<plat::NPUDeviceContext, bool>, + ops::LogicalNotNPUKernel<plat::NPUDeviceContext, int8_t>, + ops::LogicalNotNPUKernel<plat::NPUDeviceContext, int16_t>, + ops::LogicalNotNPUKernel<plat::NPUDeviceContext, int>, + ops::LogicalNotNPUKernel<plat::NPUDeviceContext, int64_t>, + 
ops::LogicalNotNPUKernel<plat::NPUDeviceContext, float>, + ops::LogicalNotNPUKernel<plat::NPUDeviceContext, double>); + +REGISTER_OP_NPU_KERNEL(logical_or, + ops::LogicalOrNPUKernel<plat::NPUDeviceContext, bool>, + ops::LogicalOrNPUKernel<plat::NPUDeviceContext, int8_t>, + ops::LogicalOrNPUKernel<plat::NPUDeviceContext, int16_t>, + ops::LogicalOrNPUKernel<plat::NPUDeviceContext, int>, + ops::LogicalOrNPUKernel<plat::NPUDeviceContext, int64_t>, + ops::LogicalOrNPUKernel<plat::NPUDeviceContext, float>, + ops::LogicalOrNPUKernel<plat::NPUDeviceContext, double>); + +REGISTER_OP_NPU_KERNEL(logical_and, + ops::LogicalAndNPUKernel<plat::NPUDeviceContext, bool>, + ops::LogicalAndNPUKernel<plat::NPUDeviceContext, int8_t>, + ops::LogicalAndNPUKernel<plat::NPUDeviceContext, int16_t>, + ops::LogicalAndNPUKernel<plat::NPUDeviceContext, int>, + ops::LogicalAndNPUKernel<plat::NPUDeviceContext, int64_t>, + ops::LogicalAndNPUKernel<plat::NPUDeviceContext, float>, + ops::LogicalAndNPUKernel<plat::NPUDeviceContext, double>); From 611cb662786f0abc69ed7077bb1f2a588e9345c5 Mon Sep 17 00:00:00 2001 From: LuJiafeng Date: Thu, 15 Jul 2021 08:49:49 +0000 Subject: [PATCH 5/7] fix bug in xpu logical op code --- paddle/fluid/operators/controlflow/logical_op_xpu.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/controlflow/logical_op_xpu.h b/paddle/fluid/operators/controlflow/logical_op_xpu.h index 9d46ad8c0447ff..aef6ae27a31945 100644 --- a/paddle/fluid/operators/controlflow/logical_op_xpu.h +++ b/paddle/fluid/operators/controlflow/logical_op_xpu.h @@ -45,7 +45,7 @@ class BinaryLogicalOpXPUKernel : public framework::OpKernel<T> { auto* x = context.Input<Tensor>("X"); auto* y = context.Input<Tensor>("Y"); auto* out = context.Output<Tensor>("Out"); - T* out_ptr = out->mutable_data<T>(context.GetPlace()); + bool* out_ptr = out->mutable_data<bool>(context.GetPlace()); const T* x_ptr = x->data<T>(); const T* y_ptr = y->data<T>(); auto& dev_ctx = @@ -153,7 +153,7 @@ class UnaryLogicalOpXPUKernel : public framework::OpKernel<T> { if (x->numel() == 0) { return; } - out->mutable_data<T>(context.GetPlace()); + out->mutable_data<bool>(context.GetPlace()); auto& dev_ctx = context.template device_context<paddle::platform::XPUDeviceContext>(); int ret = xpu::logical_not(dev_ctx.x_context(), x->data<T>(), From d6c1cb2a86ae6bacc782d10415e60a33635b4c5b Mon Sep 17 00:00:00 2001 From: LuJiafeng Date: Thu, 15 Jul 2021 11:14:12 +0000 Subject: [PATCH 6/7] update test_logical_op_npu and test_logical_op_xpu --- .../unittests/npu/test_logical_op_npu.py | 64 +++++++++++-------- .../unittests/xpu/test_logical_op_xpu.py | 63 ++++++++++-------- 2 files changed, 73 insertions(+), 54 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/npu/test_logical_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_logical_op_npu.py index 6d1327f068a528..71e7c483125988 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_logical_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_logical_op_npu.py @@ -23,6 +23,10 @@ import paddle.fluid as fluid from paddle.static import Program, program_guard +SUPPORTED_DTYPES = [ + bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64 +] + TEST_META_OP_DATA = [{ 'op_str': 'logical_and', 'binary_op': True @@ -110,13 +114,13 @@ def run_static(x_np, y_np, op_str, use_npu=False, binary_op=True): place = paddle.NPUPlace(0) exe = fluid.Executor(place) with fluid.program_guard(main_program, startup_program): - x = paddle.static.data(name='x', shape=x_np.shape, dtype='bool') + x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype) op = getattr(paddle, op_str) feed_list = {'x': x_np} if not binary_op: res = op(x) else: - y = paddle.static.data(name='y', shape=y_np.shape, dtype='bool') + y = paddle.static.data(name='y', shape=y_np.shape, dtype=y_np.dtype) feed_list['y'] = y_np res = op(x, y) exe.run(startup_program) @@ -130,17 +134,20 @@ def run_dygraph(x_np, y_np, op_str, use_npu=False, binary_op=True): place = paddle.NPUPlace(0) paddle.disable_static(place) op = getattr(paddle, op_str) - x = paddle.to_tensor(x_np) + x = 
paddle.to_tensor(x_np, dtype=x_np.dtype) if not binary_op: dygraph_result = op(x) else: - y = paddle.to_tensor(y_np) + y = paddle.to_tensor(y_np, dtype=y_np.dtype) dygraph_result = op(x, y) return dygraph_result -def np_data_generator(np_shape, *args, **kwargs): - return np.random.choice(a=[True, False], size=np_shape).astype(bool) +def np_data_generator(np_shape, dtype, *args, **kwargs): + if dtype == bool: + return np.random.choice(a=[True, False], size=np_shape).astype(bool) + else: + return np.random.randn(*np_shape).astype(dtype) def test(unit_test, use_npu=False, test_error=False): @@ -152,21 +159,27 @@ def test(unit_test, use_npu=False, test_error=False): if test_error: META_DATA = dict(TEST_META_WRONG_SHAPE_DATA) for shape_data in META_DATA.values(): - meta_data['x_np'] = np_data_generator(shape_data['x_shape']) - meta_data['y_np'] = np_data_generator(shape_data['y_shape']) - if meta_data['binary_op'] and test_error: - # catch C++ Exception - unit_test.assertRaises(BaseException, run_static, **meta_data) - unit_test.assertRaises(BaseException, run_dygraph, **meta_data) - continue - static_result = run_static(**meta_data) - dygraph_result = run_dygraph(**meta_data) - if meta_data['binary_op']: - np_result = np_op(meta_data['x_np'], meta_data['y_np']) - else: - np_result = np_op(meta_data['x_np']) - unit_test.assertTrue((static_result == np_result).all()) - unit_test.assertTrue((dygraph_result.numpy() == np_result).all()) + for data_type in SUPPORTED_DTYPES: + meta_data['x_np'] = np_data_generator( + shape_data['x_shape'], dtype=data_type) + meta_data['y_np'] = np_data_generator( + shape_data['y_shape'], dtype=data_type) + if meta_data['binary_op'] and test_error: + # catch C++ Exception + unit_test.assertRaises(BaseException, run_static, + **meta_data) + unit_test.assertRaises(BaseException, run_dygraph, + **meta_data) + continue + static_result = run_static(**meta_data) + dygraph_result = run_dygraph(**meta_data) + if meta_data['binary_op']: + np_result = np_op(meta_data['x_np'], meta_data['y_np']) + else: + np_result = np_op(meta_data['x_np']) + unit_test.assertTrue((static_result == np_result).all()) + unit_test.assertTrue((dygraph_result.numpy() == np_result).all( + )) def test_type_error(unit_test, use_npu, type_str_map): @@ -178,13 +191,12 @@ def check_type(op_str, x, y, binary_op): y = paddle.to_tensor(y) error_type = BaseException if binary_op: - if type_str_map['x'] != 'bool' or type_str_map['y'] != 'bool': + if type_str_map['x'] != type_str_map['y']: unit_test.assertRaises(error_type, op, x=x, y=y) if not fluid.in_dygraph_mode(): + error_type = TypeError unit_test.assertRaises(error_type, op, x=x, y=y, out=1) else: - if type_str_map['x'] != 'bool': - unit_test.assertRaises(error_type, op, x=x) if not fluid.in_dygraph_mode(): unit_test.assertRaises(error_type, op, x=x, out=1) @@ -212,12 +224,10 @@ def check_type(op_str, x, y, binary_op): def type_map_factory(): - x_type_list = ['float32', 'float64', 'int32', 'int64', 'bool'] - y_type_list = ['float32', 'float64', 'int32', 'int64', 'bool'] return [{ 'x': x_type, 'y': y_type - } for x_type in x_type_list for y_type in y_type_list] + } for x_type in SUPPORTED_DTYPES for y_type in SUPPORTED_DTYPES] @unittest.skipIf(not paddle.is_compiled_with_npu(), diff --git a/python/paddle/fluid/tests/unittests/xpu/test_logical_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_logical_op_xpu.py index 21eb99fcfbf919..5f5e5caa5fe553 100755 --- a/python/paddle/fluid/tests/unittests/xpu/test_logical_op_xpu.py +++ 
b/python/paddle/fluid/tests/unittests/xpu/test_logical_op_xpu.py @@ -25,6 +25,10 @@ from op_test_xpu import XPUOpTest from paddle.static import Program, program_guard +SUPPORTED_DTYPES = [ + bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64 +] + TEST_META_OP_DATA = [{ 'op_str': 'logical_and', 'binary_op': True @@ -110,13 +114,13 @@ def run_static_xpu(x_np, y_np, op_str, binary_op=True): place = paddle.XPUPlace(0) exe = fluid.Executor(place) with fluid.program_guard(main_program, startup_program): - x = paddle.static.data(name='x', shape=x_np.shape, dtype='bool') + x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype) op = getattr(paddle, op_str) feed_list = {'x': x_np} if not binary_op: res = op(x) else: - y = paddle.static.data(name='y', shape=y_np.shape, dtype='bool') + y = paddle.static.data(name='y', shape=y_np.shape, dtype=y_np.dtype) feed_list['y'] = y_np res = op(x, y) exe.run(startup_program) @@ -128,17 +132,20 @@ def run_dygraph_xpu(x_np, y_np, op_str, binary_op=True): place = paddle.XPUPlace(0) paddle.disable_static(place) op = getattr(paddle, op_str) - x = paddle.to_tensor(x_np) + x = paddle.to_tensor(x_np, dtype=x_np.dtype) if not binary_op: dygraph_result = op(x) else: - y = paddle.to_tensor(y_np) + y = paddle.to_tensor(y_np, dtype=y_np.dtype) dygraph_result = op(x, y) return dygraph_result -def np_data_generator(np_shape, *args, **kwargs): - return np.random.choice(a=[True, False], size=np_shape).astype(bool) +def np_data_generator(np_shape, dtype, *args, **kwargs): + if dtype == bool: + return np.random.choice(a=[True, False], size=np_shape).astype(bool) + else: + return np.random.randn(*np_shape).astype(dtype) def test_xpu(unit_test, test_error=False): @@ -149,21 +156,25 @@ def test_xpu(unit_test, test_error=False): if test_error: META_DATA = dict(TEST_META_WRONG_SHAPE_DATA) for shape_data in META_DATA.values(): - meta_data['x_np'] = np_data_generator(shape_data['x_shape']) - meta_data['y_np'] = np_data_generator(shape_data['y_shape']) - if meta_data['binary_op'] and test_error: - # catch C++ Exception - unit_test.assertRaises(BaseException, run_static_xpu, - **meta_data) - continue - static_result = run_static_xpu(**meta_data) - dygraph_result = run_dygraph_xpu(**meta_data) - if meta_data['binary_op']: - np_result = np_op(meta_data['x_np'], meta_data['y_np']) - else: - np_result = np_op(meta_data['x_np']) - unit_test.assertTrue((static_result == np_result).all()) - unit_test.assertTrue((dygraph_result.numpy() == np_result).all()) + for data_type in SUPPORTED_DTYPES: + meta_data['x_np'] = np_data_generator( + shape_data['x_shape'], dtype=data_type) + meta_data['y_np'] = np_data_generator( + shape_data['y_shape'], dtype=data_type) + if meta_data['binary_op'] and test_error: + # catch C++ Exception + unit_test.assertRaises(BaseException, run_static_xpu, + **meta_data) + continue + static_result = run_static_xpu(**meta_data) + dygraph_result = run_dygraph_xpu(**meta_data) + if meta_data['binary_op']: + np_result = np_op(meta_data['x_np'], meta_data['y_np']) + else: + np_result = np_op(meta_data['x_np']) + unit_test.assertTrue((static_result == np_result).all()) + unit_test.assertTrue((dygraph_result.numpy() == np_result).all( + )) def test_type_error(unit_test, type_str_map): @@ -175,14 +186,14 @@ def check_type(op_str, x, y, binary_op): y = paddle.to_tensor(y) error_type = BaseException if binary_op: - if type_str_map['x'] != 'bool' or type_str_map['y'] != 'bool': + if type_str_map['x'] != type_str_map['y']: unit_test.assertRaises(error_type, 
op, x=x, y=y) if not fluid.in_dygraph_mode(): + error_type = TypeError unit_test.assertRaises(error_type, op, x=x, y=y, out=1) else: - if type_str_map['x'] != 'bool': - unit_test.assertRaises(error_type, op, x=x) if not fluid.in_dygraph_mode(): + error_type = TypeError unit_test.assertRaises(error_type, op, x=x, out=1) place = paddle.XPUPlace(0) @@ -208,12 +219,10 @@ def check_type(op_str, x, y, binary_op): def type_map_factory(): - x_type_list = ['float32', 'float64', 'int32', 'int64', 'bool'] - y_type_list = ['float32', 'float64', 'int32', 'int64', 'bool'] return [{ 'x': x_type, 'y': y_type - } for x_type in x_type_list for y_type in y_type_list] + } for x_type in SUPPORTED_DTYPES for y_type in SUPPORTED_DTYPES] @unittest.skipIf(not paddle.is_compiled_with_xpu(), From acdaaf88b089bbd057d25708ca2caf6068180196 Mon Sep 17 00:00:00 2001 From: LuJiafeng Date: Thu, 15 Jul 2021 11:46:46 +0000 Subject: [PATCH 7/7] correct error type --- python/paddle/fluid/tests/unittests/npu/test_logical_op_npu.py | 3 ++- python/paddle/fluid/tests/unittests/xpu/test_logical_op_xpu.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/npu/test_logical_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_logical_op_npu.py index 71e7c483125988..eb3b6876754a16 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_logical_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_logical_op_npu.py @@ -185,7 +185,7 @@ def test(unit_test, use_npu=False, test_error=False): def test_type_error(unit_test, use_npu, type_str_map): def check_type(op_str, x, y, binary_op): op = getattr(paddle, op_str) - error_type = TypeError + error_type = ValueError if isinstance(x, np.ndarray): x = paddle.to_tensor(x) y = paddle.to_tensor(y) @@ -198,6 +198,7 @@ def check_type(op_str, x, y, binary_op): unit_test.assertRaises(error_type, op, x=x, y=y, out=1) else: if not fluid.in_dygraph_mode(): + error_type = TypeError unit_test.assertRaises(error_type, op, x=x, out=1) place = paddle.CPUPlace() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_logical_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_logical_op_xpu.py index 5f5e5caa5fe553..7e7481bd90646c 100755 --- a/python/paddle/fluid/tests/unittests/xpu/test_logical_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_logical_op_xpu.py @@ -180,7 +180,7 @@ def test_xpu(unit_test, test_error=False): def test_type_error(unit_test, type_str_map): def check_type(op_str, x, y, binary_op): op = getattr(paddle, op_str) - error_type = TypeError + error_type = ValueError if isinstance(x, np.ndarray): x = paddle.to_tensor(x) y = paddle.to_tensor(y)
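
Note: as a closing illustration, separate from the series itself, here is a condensed, self-contained version of the cross-dtype check the updated unit tests perform. It is a sketch that assumes a CPU build of Paddle with all seven patches applied, and it mirrors the SUPPORTED_DTYPES list and np_data_generator logic added in the tests.

    import numpy as np
    import paddle

    SUPPORTED_DTYPES = [
        bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64
    ]

    paddle.disable_static(paddle.CPUPlace())
    for op_str, np_op in [('logical_and', np.logical_and),
                          ('logical_or', np.logical_or),
                          ('logical_xor', np.logical_xor)]:
        for dtype in SUPPORTED_DTYPES:
            if dtype == bool:
                x_np = np.random.choice([True, False], size=(3, 4))
                y_np = np.random.choice([True, False], size=(3, 4))
            else:
                # randn values truncate to many zeros for integer dtypes,
                # so both truthy and falsy elements are exercised.
                x_np = np.random.randn(3, 4).astype(dtype)
                y_np = np.random.randn(3, 4).astype(dtype)
            out = getattr(paddle, op_str)(paddle.to_tensor(x_np),
                                          paddle.to_tensor(y_np))
            # The Paddle result is boolean and must match NumPy exactly.
            assert out.numpy().dtype == np.bool_
            assert (out.numpy() == np_op(x_np, y_np)).all()
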