Skip to content

Commit fbded73

Browse files
author
root
committed
fix trace op stack overflow
1 parent 67ed7e1 commit fbded73

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

python/paddle/tensor/math.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1657,12 +1657,6 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
16571657
data2 = paddle.trace(case2, offset=1, axis1=1, axis2=2) # data2.shape = [3]
16581658
data3 = paddle.trace(case3, offset=-3, axis1=1, axis2=-1) # data2.shape = [3, 5]
16591659
"""
1660-
if in_dygraph_mode():
1661-
return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
1662-
1663-
inputs = {'Input': [x]}
1664-
attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2}
1665-
16661660
def __check_input(input, offset, dim1, dim2):
16671661
check_dtype(x.dtype, 'Input',
16681662
['int32', 'int64', 'float16', 'float32', 'float64'],
@@ -1677,11 +1671,11 @@ def __check_input(input, offset, dim1, dim2):
16771671
axis1_ = axis1 if axis1 >= 0 else len(input_shape) + axis1
16781672
axis2_ = axis2 if axis2 >= 0 else len(input_shape) + axis2
16791673

1680-
assert axis1_ < len(input_shape), \
1674+
assert ((0 <= axis1_) and (axis1_ < len(input_shape))), \
16811675
"The argument axis1 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
16821676
% (-(len(input_shape)), len(input_shape) - 1, axis1)
16831677

1684-
assert axis2_ < len(input_shape), \
1678+
assert ((0 <= axis2_) and (axis2_ < len(input_shape))), \
16851679
"The argument axis2 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
16861680
% (-(len(input_shape)), len(input_shape) - 1, axis2)
16871681

@@ -1691,6 +1685,11 @@ def __check_input(input, offset, dim1, dim2):
16911685
"But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)
16921686

16931687
__check_input(input, offset, axis1, axis2)
1688+
if in_dygraph_mode():
1689+
return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
1690+
1691+
inputs = {'Input': [x]}
1692+
attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2}
16941693
helper = LayerHelper('trace', **locals())
16951694

16961695
out = helper.create_variable_for_type_inference(dtype=x.dtype)

0 commit comments

Comments
 (0)