diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc index 24b4615daa58c2..b962f56cf9d2a7 100644 --- a/paddle/phi/kernels/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -110,8 +110,10 @@ PD_REGISTER_KERNEL(equal_all, ALL_LAYOUT, \ phi::func##Kernel, \ bool, \ - int16_t, \ int, \ + uint8_t, \ + int8_t, \ + int16_t, \ int64_t, \ float, \ double, \ @@ -119,6 +121,7 @@ PD_REGISTER_KERNEL(equal_all, phi::dtype::bfloat16) { \ kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ } + PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan) diff --git a/paddle/phi/kernels/kps/compare_kernel.cu b/paddle/phi/kernels/kps/compare_kernel.cu index 14bb86b4753207..b10412ed1358c4 100644 --- a/paddle/phi/kernels/kps/compare_kernel.cu +++ b/paddle/phi/kernels/kps/compare_kernel.cu @@ -150,8 +150,10 @@ PD_REGISTER_KERNEL(equal_all, ALL_LAYOUT, \ phi::func##Kernel, \ bool, \ - int16_t, \ int, \ + uint8_t, \ + int8_t, \ + int16_t, \ int64_t, \ float, \ double, \ diff --git a/paddle/phi/kernels/legacy/cpu/compare_kernel.cc b/paddle/phi/kernels/legacy/cpu/compare_kernel.cc index d9760398af7cc6..66bbb806adb676 100644 --- a/paddle/phi/kernels/legacy/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/legacy/cpu/compare_kernel.cc @@ -115,6 +115,8 @@ PD_REGISTER_KERNEL(less_than_raw, ALL_LAYOUT, phi::LessThanRawKernel, bool, + uint8_t, + int8_t, int16_t, int, int64_t, @@ -131,6 +133,8 @@ PD_REGISTER_KERNEL(less_than_raw, ALL_LAYOUT, \ phi::func##RawKernel, \ bool, \ + uint8_t, \ + int8_t, \ int16_t, \ int, \ int64_t, \ diff --git a/paddle/phi/kernels/legacy/kps/compare_kernel.cu b/paddle/phi/kernels/legacy/kps/compare_kernel.cu index 67bd491738346e..429cff41886a1f 100644 --- a/paddle/phi/kernels/legacy/kps/compare_kernel.cu +++ b/paddle/phi/kernels/legacy/kps/compare_kernel.cu @@ -139,6 +139,8 @@ PD_REGISTER_KERNEL(less_than_raw, ALL_LAYOUT, phi::LessThanRawKernel, bool, + uint8_t, + int8_t, int16_t, int, int64_t, @@ -155,8 +157,10 @@ PD_REGISTER_KERNEL(less_than_raw, ALL_LAYOUT, \ phi::func##RawKernel, \ bool, \ + uint8_t, \ int16_t, \ int, \ + int8_t, \ int64_t, \ float, \ double, \ diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 9b50993b891667..d30fe3d0b90875 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -512,8 +512,8 @@ def equal(x, y, name=None): The output has no gradient. Args: - x (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64. - y (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64. + x (Tensor): Tensor, data type is bool, float16, float32, float64, uint8, int8, int16, int32, int64. + y (Tensor): Tensor, data type is bool, float16, float32, float64, uint8, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -553,6 +553,9 @@ def equal(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -567,6 +570,9 @@ def equal(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -611,8 +617,8 @@ def greater_equal(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -641,6 +647,9 @@ def greater_equal(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -655,6 +664,9 @@ def greater_equal(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -699,8 +711,8 @@ def greater_than(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -729,6 +741,9 @@ def greater_than(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -743,6 +758,9 @@ def greater_than(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -787,8 +805,8 @@ def less_equal(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -818,6 +836,9 @@ def less_equal(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -832,6 +853,9 @@ def less_equal(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -876,8 +900,8 @@ def less_than(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -907,6 +931,9 @@ def less_than(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -921,6 +948,9 @@ def less_than(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -965,8 +995,8 @@ def not_equal(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, uint8, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, uint8, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -996,6 +1026,9 @@ def not_equal(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", @@ -1010,6 +1043,9 @@ def not_equal(x, y, name=None): "float16", "float32", "float64", + "uint8", + "int8", + "int16", "int32", "int64", "uint16", diff --git a/test/legacy_test/test_compare_op.py b/test/legacy_test/test_compare_op.py index 91dce088ef88ef..79aa2736eeb0cb 100755 --- a/test/legacy_test/test_compare_op.py +++ b/test/legacy_test/test_compare_op.py @@ -38,24 +38,35 @@ def setUp(self): def test_output(self): self.check_output(check_cinn=True, check_pir=check_pir) - def test_errors(self): + def test_int16_support(self): paddle.enable_static() with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program() ): - x = paddle.static.data(name='x', shape=[-1, 2], dtype='int32') - y = paddle.static.data(name='y', shape=[-1, 2], dtype='int32') a = paddle.static.data(name='a', shape=[-1, 2], dtype='int16') + b = paddle.static.data(name='b', shape=[-1, 2], dtype='int16') op = eval("paddle.%s" % self.op_type) - self.assertRaises(TypeError, op, x=x, y=a) - self.assertRaises(TypeError, op, x=a, y=y) + + try: + result = op(x=a, y=b) + except TypeError: + self.fail("TypeError should not be raised for int16 inputs") cls_name = f"{op_type}_{typename}" Cls.__name__ = cls_name globals()[cls_name] = Cls -for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}: +for _type_name in { + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + 'float16', +}: if _type_name == 'float64' and core.is_compiled_with_rocm(): _type_name = 'float32' if _type_name == 'float16' and (not core.is_compiled_with_cuda()): @@ -513,7 +524,7 @@ def test_check_output(self): class TestCompareOpError(unittest.TestCase): - def test_errors(self): + def test_int16_support(self): paddle.enable_static() with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program()