
Commit ff062a4

fix output dtype for paddle.sum (#34313)
* support bool dtype for paddle.sum
1 parent a842828 commit ff062a4
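
As a quick illustration of the headline change: with this commit, bool joins the dtypes paddle.sum accepts, and bool/int32/int64 inputs are accumulated into an int64 result by default. A minimal dygraph sketch (hedged: the printed dtype is inferred from the get_dtype helper added below, not re-verified here):

    import paddle

    # Boolean input is now accepted; per the new get_dtype helper in
    # python/paddle/tensor/math.py, the result is accumulated as int64.
    mask = paddle.to_tensor([True, False, True, True])
    total = paddle.sum(mask)
    print(total)  # expected: an int64 Tensor with value 3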

File tree

3 files changed: +31 -58 lines

python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ def test_optim_break_in_while(x):
 
 class TestContinueInFor(unittest.TestCase):
     def setUp(self):
-        self.input = np.zeros((1)).astype('int32')
+        self.input = np.zeros((1)).astype('int64')
         self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
         self.init_dygraph_func()

python/paddle/fluid/tests/unittests/test_reduce_op.py

Lines changed: 12 additions & 31 deletions
@@ -748,37 +748,6 @@ def test_errors(self):
         self.assertRaises(TypeError, fluid.layers.reduce_sum, x2)
 
 
-class API_TestSumOpError(unittest.TestCase):
-    def test_errors(self):
-        def test_dtype1():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="float64")
-                paddle.sum(data, dtype="float32")
-
-        self.assertRaises(ValueError, test_dtype1)
-
-        def test_dtype2():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="int64")
-                paddle.sum(data, dtype="int32")
-
-        self.assertRaises(ValueError, test_dtype2)
-
-        def test_dtype3():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="float64")
-                paddle.sum(data, dtype="int32")
-
-        self.assertRaises(ValueError, test_dtype3)
-
-        def test_type():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="int32")
-                paddle.sum(data, dtype="bool")
-
-        self.assertRaises(TypeError, test_type)
-
-
 class API_TestSumOp(unittest.TestCase):
     def run_static(self,
                    shape,
@@ -805,14 +774,26 @@ def test_static(self):
         shape = [10, 10]
         axis = 1
 
+        self.run_static(shape, "bool", axis, attr_dtype=None)
+        self.run_static(shape, "bool", axis, attr_dtype="int32")
+        self.run_static(shape, "bool", axis, attr_dtype="int64")
+
         self.run_static(shape, "int32", axis, attr_dtype=None)
         self.run_static(shape, "int32", axis, attr_dtype="int32")
         self.run_static(shape, "int32", axis, attr_dtype="int64")
 
+        self.run_static(shape, "int64", axis, attr_dtype=None)
+        self.run_static(shape, "int64", axis, attr_dtype="int64")
+        self.run_static(shape, "int64", axis, attr_dtype="int32")
+
         self.run_static(shape, "float32", axis, attr_dtype=None)
         self.run_static(shape, "float32", axis, attr_dtype="float32")
         self.run_static(shape, "float32", axis, attr_dtype="float64")
 
+        self.run_static(shape, "float64", axis, attr_dtype=None)
+        self.run_static(shape, "float64", axis, attr_dtype="float32")
+        self.run_static(shape, "float64", axis, attr_dtype="float64")
+
         shape = [5, 5, 5]
         self.run_static(shape, "int32", (0, 1), attr_dtype="int32")
         self.run_static(
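
The deleted API_TestSumOpError cases asserted that mixed input/attr(dtype) combinations raise, which the relaxed dtype handling below no longer does; the new run_static calls instead verify that every listed input dtype (now including bool, int64 and float64) computes for several attr_dtype values. The body of run_static is not shown in this diff; the following is only a rough, hypothetical sketch of such a static-graph check (check_sum_static, the random data, and the CPU place are illustrative assumptions, not the test's actual helper):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    def check_sum_static(shape, dtype, axis, attr_dtype=None):
        # Hypothetical stand-in for the test's run_static helper: build a
        # small program, run paddle.sum with an explicit output dtype, and
        # compare against numpy.
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data = fluid.data(name="data", shape=shape, dtype=dtype)
            out = paddle.sum(data, axis=axis, dtype=attr_dtype)
            x = np.random.randint(0, 2, shape).astype(dtype)
            exe = fluid.Executor(fluid.CPUPlace())
            result, = exe.run(feed={"data": x}, fetch_list=[out])
        np.testing.assert_allclose(result, x.sum(axis=axis, dtype=attr_dtype))

    check_sum_static([10, 10], "bool", 1, attr_dtype="int64")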

python/paddle/tensor/math.py

Lines changed: 18 additions & 26 deletions
@@ -716,13 +716,15 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     else:
         reduce_all_flag = False
 
-    dtype_flag = False
-    if dtype is not None:
-        if dtype in ['float64', 'int64']:
-            if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \
-                (convert_dtype(x.dtype) == "int32" and dtype == "int64"):
-                dtype_flag = True
-
+    def get_dtype(x, dtype):
+        if dtype is not None:
+            return (True, dtype)
+        src_type = convert_dtype(x.dtype)
+        if src_type in ['bool','int32', 'int64']:
+            return (True, 'int64')
+        return (False, src_type)
+
+    dtype_flag, dtype = get_dtype(x, dtype)
     if in_dygraph_mode():
         axis = axis if axis != None and axis != [] else [0]
         if dtype_flag:
@@ -740,27 +742,17 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
         'reduce_all': reduce_all_flag
     }
 
-    if dtype is not None:
-        if dtype in ['float64', 'int64']:
-            if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \
-                (convert_dtype(x.dtype) == "int32" and dtype == "int64"):
-                attrs.update({
-                    'in_dtype': x.dtype,
-                    'out_dtype': convert_np_dtype_to_dtype_(dtype)
-                })
+    if dtype_flag:
+        attrs.update({
+            'in_dtype': x.dtype,
+            'out_dtype': convert_np_dtype_to_dtype_(dtype)
+        })
 
     check_variable_and_dtype(
-        x, 'x', ['float32', 'float64', 'int32', 'int64'], 'sum')
-
-    if dtype is not None:
-        check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'sum')
-        x_dtype = convert_dtype(x.dtype)
-
-        if (x_dtype == "float64" and dtype in ["float32", "int32"]) or \
-                (x_dtype == "int64" and dtype == "int32"):
-            raise ValueError("The input(x)'s dtype is {} but the attr(dtype) of sum is {}, "
-                             "which may cause data type overflows. Please reset attr(dtype) of sum."
-                             .format(x_dtype, dtype))
+        x, 'x', ['bool', 'float16', 'float32', 'float64',
+                 'int32', 'int64', 'complex64', 'complex128',
+                 u'bool', u'float16', u'float32', u'float64',
+                 u'int32', u'int64', u'complex64', u'complex128'], 'sum')
 
     check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum')
 
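
Taken together, the new get_dtype helper changes the default accumulation dtype of paddle.sum: when no dtype argument is passed, bool, int32 and int64 inputs produce an int64 output while float inputs keep their own dtype, and an explicit dtype is now passed through unchanged instead of being validated against the old ValueError rules. A small dygraph sketch of the expected behavior (output dtypes inferred from the diff above, assuming a build that contains this commit):

    import paddle

    # Integer-like inputs are accumulated in int64 by default.
    i32 = paddle.arange(6, dtype='int32').reshape([2, 3])
    print(paddle.sum(i32, axis=1).dtype)           # expected: int64

    # Float inputs keep their dtype unless dtype= is given explicitly.
    f32 = paddle.rand([2, 3], dtype='float32')
    print(paddle.sum(f32).dtype)                   # expected: float32
    print(paddle.sum(f32, dtype='float64').dtype)  # expected: float64

    # A narrowing request such as dtype='int32' on an int64 input no longer
    # raises ValueError; the explicit dtype is used for the output.
    i64 = paddle.to_tensor([1, 2, 3], dtype='int64')
    print(paddle.sum(i64, dtype='int32').dtype)    # expected: int32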