Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_optim_break_in_while(x):

class TestContinueInFor(unittest.TestCase):
def setUp(self):
self.input = np.zeros((1)).astype('int32')
self.input = np.zeros((1)).astype('int64')
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.init_dygraph_func()
Expand Down
43 changes: 12 additions & 31 deletions python/paddle/fluid/tests/unittests/test_reduce_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,37 +748,6 @@ def test_errors(self):
self.assertRaises(TypeError, fluid.layers.reduce_sum, x2)


class API_TestSumOpError(unittest.TestCase):
    """Verify that paddle.sum rejects unsafe dtype conversions in static mode.

    Narrowing casts (float64 -> float32/int32, int64 -> int32) must raise
    ValueError; an unsupported output dtype ('bool') must raise TypeError.
    """

    def test_errors(self):
        def build_sum(in_dtype, out_dtype):
            # Fresh main/startup programs so every case is isolated.
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                var = fluid.data(name="data", shape=[10], dtype=in_dtype)
                paddle.sum(var, dtype=out_dtype)

        # float64 input summed into float32 loses precision -> ValueError.
        self.assertRaises(ValueError, build_sum, "float64", "float32")
        # int64 input summed into int32 risks overflow -> ValueError.
        self.assertRaises(ValueError, build_sum, "int64", "int32")
        # float64 input summed into int32 is a narrowing cross-kind cast -> ValueError.
        self.assertRaises(ValueError, build_sum, "float64", "int32")
        # 'bool' is not an accepted output dtype for sum -> TypeError.
        self.assertRaises(TypeError, build_sum, "int32", "bool")


class API_TestSumOp(unittest.TestCase):
def run_static(self,
shape,
Expand All @@ -805,14 +774,26 @@ def test_static(self):
shape = [10, 10]
axis = 1

self.run_static(shape, "bool", axis, attr_dtype=None)
self.run_static(shape, "bool", axis, attr_dtype="int32")
self.run_static(shape, "bool", axis, attr_dtype="int64")

self.run_static(shape, "int32", axis, attr_dtype=None)
self.run_static(shape, "int32", axis, attr_dtype="int32")
self.run_static(shape, "int32", axis, attr_dtype="int64")

self.run_static(shape, "int64", axis, attr_dtype=None)
self.run_static(shape, "int64", axis, attr_dtype="int64")
self.run_static(shape, "int64", axis, attr_dtype="int32")

self.run_static(shape, "float32", axis, attr_dtype=None)
self.run_static(shape, "float32", axis, attr_dtype="float32")
self.run_static(shape, "float32", axis, attr_dtype="float64")

self.run_static(shape, "float64", axis, attr_dtype=None)
self.run_static(shape, "float64", axis, attr_dtype="float32")
self.run_static(shape, "float64", axis, attr_dtype="float64")

shape = [5, 5, 5]
self.run_static(shape, "int32", (0, 1), attr_dtype="int32")
self.run_static(
Expand Down
44 changes: 18 additions & 26 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,13 +716,15 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
else:
reduce_all_flag = False

dtype_flag = False
if dtype is not None:
if dtype in ['float64', 'int64']:
if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \
(convert_dtype(x.dtype) == "int32" and dtype == "int64"):
dtype_flag = True

def get_dtype(x, dtype):
if dtype is not None:
return (True, dtype)
src_type = convert_dtype(x.dtype)
if src_type in ['bool','int32', 'int64']:
return (True, 'int64')
return (False, src_type)

dtype_flag, dtype = get_dtype(x, dtype)
if in_dygraph_mode():
axis = axis if axis != None and axis != [] else [0]
if dtype_flag:
Expand All @@ -740,27 +742,17 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
'reduce_all': reduce_all_flag
}

if dtype is not None:
if dtype in ['float64', 'int64']:
if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \
(convert_dtype(x.dtype) == "int32" and dtype == "int64"):
attrs.update({
'in_dtype': x.dtype,
'out_dtype': convert_np_dtype_to_dtype_(dtype)
})
if dtype_flag:
attrs.update({
'in_dtype': x.dtype,
'out_dtype': convert_np_dtype_to_dtype_(dtype)
})

check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'sum')

if dtype is not None:
check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'sum')
x_dtype = convert_dtype(x.dtype)

if (x_dtype == "float64" and dtype in ["float32", "int32"]) or \
(x_dtype == "int64" and dtype == "int32"):
raise ValueError("The input(x)'s dtype is {} but the attr(dtype) of sum is {}, "
"which may cause data type overflows. Please reset attr(dtype) of sum."
.format(x_dtype, dtype))
x, 'x', ['bool', 'float16', 'float32', 'float64',
'int32', 'int64', 'complex64', 'complex128',
u'bool', u'float16', u'float32', u'float64',
u'int32', u'int64', u'complex64', u'complex128'], 'sum')

check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum')

Expand Down