Skip to content

Commit 53f2a48

Browse files
[API Compatibility] Add swapaxes and swapdims as transpose alias (#74864)
* Add alias for transpose * Rename swapaxis -> swapaxes * Support axis0 axis1 param for swapaxes * Refine swapaxes * Fix test error and export swapaxes and swapdims * Support axis0 axis1 params for transpose * rerun ci * Fix
1 parent d83b10e commit 53f2a48

File tree

5 files changed

+220
-25
lines changed

5 files changed

+220
-25
lines changed

python/paddle/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,9 +864,11 @@ def __dir__(self):
864864
take_along_dim = take_along_axis
865865
clamp = clip
866866
ger = outer
867-
868867
div = divide
869868
div_ = divide_
869+
swapdims = transpose
870+
swapaxes = transpose
871+
870872

871873
__all__ = [
872874
'block_diag',
@@ -1182,6 +1184,8 @@ def __dir__(self):
11821184
'tanh',
11831185
'tanh_',
11841186
'transpose',
1187+
'swapaxes',
1188+
'swapdims',
11851189
'transpose_',
11861190
'permute',
11871191
'cauchy_',

python/paddle/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,8 @@
497497
# API alias
498498
div = divide
499499
div_ = divide_
500+
swapdims = transpose
501+
swapaxes = transpose
500502

501503
# this list used in math_op_patch.py for _binary_creator_
502504
tensor_method_func = [
@@ -728,6 +730,8 @@
728730
'stack',
729731
'strided_slice',
730732
'transpose',
733+
'swapaxes',
734+
'swapdims',
731735
'transpose_',
732736
'permute',
733737
'cauchy_',

python/paddle/utils/decorator_utils.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -445,30 +445,21 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
445445
if ("input" in kwargs) and ("x" not in kwargs):
446446
kwargs["x"] = kwargs.pop("input")
447447

448-
has_dim0 = "dim0" in kwargs or (
449-
len(args) > 1 and isinstance(args[1], int)
450-
)
451-
if has_dim0:
452-
dim0 = kwargs.pop(
453-
"dim0",
454-
args[1]
455-
if (len(args) > 1 and isinstance(args[1], int))
456-
else None,
457-
)
458-
dim1 = kwargs.pop(
459-
"dim1",
460-
args[2]
461-
if (len(args) > 2 and isinstance(args[2], int))
462-
else None,
463-
)
464-
465-
if dim0 is not None and dim1 is not None:
466-
ndim = kwargs["x"].ndim if "x" in kwargs else args[0].ndim
467-
perm = list(range(ndim))
468-
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
469-
kwargs["perm"] = perm
470-
if len(args) > 1:
471-
args = (args[0],)
448+
dim0 = kwargs.pop("dim0", kwargs.pop("axis0", None))
449+
dim1 = kwargs.pop("dim1", kwargs.pop("axis1", None))
450+
451+
if dim0 is None and len(args) > 1 and isinstance(args[1], int):
452+
dim0 = args[1]
453+
if dim1 is None and len(args) > 2 and isinstance(args[2], int):
454+
dim1 = args[2]
455+
456+
if dim0 is not None and dim1 is not None:
457+
ndim = kwargs["x"].ndim if "x" in kwargs else args[0].ndim
458+
perm = list(range(ndim))
459+
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
460+
kwargs["perm"] = perm
461+
if len(args) > 1:
462+
args = (args[0],)
472463

473464
return func(*args, **kwargs)
474465

test/legacy_test/test_swapaxes.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
from utils import dygraph_guard, static_guard
19+
20+
import paddle
21+
22+
23+
class TestSwapaxesCompatibility(unittest.TestCase):
24+
def setUp(self):
25+
self.places = [paddle.CPUPlace()]
26+
if paddle.base.core.is_compiled_with_cuda():
27+
self.places.append(paddle.CUDAPlace(0))
28+
self.func = paddle.swapaxes
29+
self.init_data()
30+
31+
def init_data(self):
32+
self.shape = [4, 5, 6]
33+
self.dtype = 'float32'
34+
self.dim0 = 0
35+
self.dim1 = 1
36+
self.perm = [1, 0, 2]
37+
38+
self.np_input = np.random.rand(*self.shape).astype(self.dtype)
39+
self.np_out = np.transpose(self.np_input, axes=self.perm)
40+
41+
def test_dygraph_compatibility(self):
42+
with dygraph_guard():
43+
for place in self.places:
44+
paddle.device.set_device(place)
45+
x = paddle.to_tensor(self.np_input)
46+
outs = []
47+
outs.append(paddle.swapaxes(x, perm=self.perm))
48+
outs.append(paddle.swapaxes(x=x, perm=self.perm))
49+
outs.append(paddle.swapaxes(input=x, perm=self.perm))
50+
outs.append(paddle.swapaxes(x, self.dim0, self.dim1))
51+
outs.append(
52+
paddle.swapaxes(x=x, axis0=self.dim0, axis1=self.dim1)
53+
)
54+
outs.append(
55+
paddle.swapaxes(input=x, axis0=self.dim0, axis1=self.dim1)
56+
)
57+
58+
outs.append(x.swapaxes(self.perm))
59+
outs.append(x.swapaxes(self.dim0, self.dim1))
60+
outs.append(x.swapaxes(perm=self.perm))
61+
outs.append(x.swapaxes(axis0=self.dim0, axis1=self.dim1))
62+
outs.append(x.swapaxes(self.dim0, axis1=self.dim1))
63+
64+
for out in outs:
65+
np.testing.assert_array_equal(self.np_out, out.numpy())
66+
67+
def test_static_compatibility(self):
68+
with static_guard():
69+
for place in self.places:
70+
main = paddle.static.Program()
71+
startup = paddle.static.Program()
72+
with paddle.base.program_guard(main, startup):
73+
x = paddle.static.data(
74+
name="x", shape=self.shape, dtype=self.dtype
75+
)
76+
outs = []
77+
outs.append(paddle.swapaxes(x, perm=self.perm))
78+
outs.append(paddle.swapaxes(x=x, perm=self.perm))
79+
outs.append(paddle.swapaxes(input=x, perm=self.perm))
80+
outs.append(paddle.swapaxes(x, self.dim0, self.dim1))
81+
outs.append(
82+
paddle.swapaxes(x=x, axis0=self.dim0, axis1=self.dim1)
83+
)
84+
outs.append(
85+
paddle.swapaxes(
86+
input=x, axis0=self.dim0, axis1=self.dim1
87+
)
88+
)
89+
90+
outs.append(x.swapaxes(self.perm))
91+
outs.append(x.swapaxes(self.dim0, self.dim1))
92+
outs.append(x.swapaxes(perm=self.perm))
93+
outs.append(x.swapaxes(axis0=self.dim0, axis1=self.dim1))
94+
outs.append(x.swapaxes(self.dim0, axis1=self.dim1))
95+
96+
exe = paddle.base.Executor(place)
97+
fetches = exe.run(
98+
main,
99+
feed={"x": self.np_input},
100+
fetch_list=outs,
101+
)
102+
for out in fetches:
103+
np.testing.assert_array_equal(self.np_out, out)
104+
105+
106+
if __name__ == "__main__":
107+
unittest.main()

test/legacy_test/test_swapdims.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
from utils import dygraph_guard, static_guard
19+
20+
import paddle
21+
22+
23+
class TestswapdimsCompatibility(unittest.TestCase):
24+
def setUp(self):
25+
self.places = [paddle.CPUPlace()]
26+
if paddle.base.core.is_compiled_with_cuda():
27+
self.places.append(paddle.CUDAPlace(0))
28+
self.func = paddle.swapdims
29+
self.init_data()
30+
31+
def init_data(self):
32+
self.shape = [4, 5, 6]
33+
self.dtype = 'float32'
34+
self.dim0 = 0
35+
self.dim1 = 1
36+
self.perm = [1, 0, 2]
37+
38+
self.np_input = np.random.rand(*self.shape).astype(self.dtype)
39+
self.np_out = np.transpose(self.np_input, axes=self.perm)
40+
41+
def test_dygraph_compatibility(self):
42+
with dygraph_guard():
43+
for place in self.places:
44+
paddle.device.set_device(place)
45+
x = paddle.to_tensor(self.np_input)
46+
outs = []
47+
outs.append(paddle.swapdims(x, self.dim0, self.dim1))
48+
outs.append(
49+
paddle.swapdims(input=x, dim0=self.dim0, dim1=self.dim1)
50+
)
51+
52+
outs.append(x.swapdims(self.dim0, self.dim1))
53+
outs.append(x.swapdims(dim0=self.dim0, dim1=self.dim1))
54+
outs.append(x.swapdims(self.dim0, dim1=self.dim1))
55+
56+
for out in outs:
57+
np.testing.assert_array_equal(self.np_out, out.numpy())
58+
59+
def test_static_compatibility(self):
60+
with static_guard():
61+
for place in self.places:
62+
main = paddle.static.Program()
63+
startup = paddle.static.Program()
64+
with paddle.base.program_guard(main, startup):
65+
x = paddle.static.data(
66+
name="x", shape=self.shape, dtype=self.dtype
67+
)
68+
outs = []
69+
outs.append(paddle.swapdims(x, self.dim0, self.dim1))
70+
outs.append(
71+
paddle.swapdims(input=x, dim0=self.dim0, dim1=self.dim1)
72+
)
73+
74+
outs.append(x.swapdims(self.dim0, self.dim1))
75+
outs.append(x.swapdims(dim0=self.dim0, dim1=self.dim1))
76+
outs.append(x.swapdims(self.dim0, dim1=self.dim1))
77+
78+
exe = paddle.base.Executor(place)
79+
fetches = exe.run(
80+
main,
81+
feed={"x": self.np_input},
82+
fetch_list=outs,
83+
)
84+
for out in fetches:
85+
np.testing.assert_array_equal(self.np_out, out)
86+
87+
88+
if __name__ == "__main__":
89+
unittest.main()

0 commit comments

Comments
 (0)