Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,8 +1483,7 @@ def clip(x, min=None, max=None, name=None):

def trace(x, offset=0, axis1=0, axis2=1, name=None):
"""
:alias_main: paddle.trace
:alias: paddle.trace,paddle.tensor.trace,paddle.tensor.math.trace
**trace**

This OP computes the sum along diagonals of the input tensor x.

Expand All @@ -1499,32 +1498,26 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
- Note that if offset is out of input's shape indicated by axis1 and axis2, 0 will be returned.

Args:
x(Variable): The input tensor x. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64.
x(Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64.
offset(int, optional): Which diagonals in input tensor x will be taken. Default: 0 (main diagonals).
axis1(int, optional): The first axis with respect to take diagonal. Default: 0.
axis2(int, optional): The second axis with respect to take diagonal. Default: 1.
name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.

Returns:
Variable: the output data type is the same as input data type.
Tensor: the output data type is the same as input data type.

Examples:
.. code-block:: python

import paddle
import numpy as np

case1 = np.random.randn(2, 3).astype('float32')
case2 = np.random.randn(3, 10, 10).astype('float32')
case3 = np.random.randn(3, 10, 5, 10).astype('float32')

paddle.disable_static()

case1 = paddle.to_tensor(case1)
case2 = paddle.to_tensor(case2)
case3 = paddle.to_tensor(case3)
case1 = paddle.randn([2, 3])
case2 = paddle.randn([3, 10, 10])
case3 = paddle.randn([3, 10, 5, 10])
data1 = paddle.trace(case1) # data1.shape = [1]
data2 = paddle.trace(case2, offset=1, axis1=1, axis2=2) # data2.shape = [3]
data3 = paddle.trace(case3, offset=-3, axis1=1, axis2=-1) # data2.shape = [3, 5]
Expand Down Expand Up @@ -1559,6 +1552,9 @@ def __check_input(input, offset, dim1, dim2):
"axis1 and axis2 cannot be the same axis." \
"But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)

if in_dygraph_mode():
return core.ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)

if not in_dygraph_mode():
__check_input(input, offset, axis1, axis2)
helper = LayerHelper('trace', **locals())
Expand Down