Skip to content

Commit cb5ae32

Browse files
authored
[API Compatibility] add msort, ravel, scatter_add, scatter_add_, tril, triu, bmm, nn.GELU, broadcast_shapes (#656)
* msort api * add ravel, scatter_add, scatter_add_, tril, triu, bmm, nn.GELU, broadcast_shapes * fix testcase * fix codestyle * fix accuracy
1 parent 3f3e71c commit cb5ae32

File tree

9 files changed

+338
-8
lines changed

9 files changed

+338
-8
lines changed

paconvert/global_var.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,24 @@ class GlobalManager:
153153

154154

155155
# shijie
156-
"torch.Tensor.scatter_reduce",
157-
"torch.scatter_reduce",
156+
"torch.msort",
157+
"torch.Tensor.msort",
158+
"torch.Tensor.ravel",
159+
"torch.ravel",
160+
"torch.Tensor.scatter_add",
161+
"torch.scatter_add",
162+
"torch.Tensor.scatter_add_",
163+
"torch.Tensor.tril",
164+
"torch.tril",
165+
"torch.Tensor.triu",
166+
"torch.triu",
167+
"torch.bmm",
168+
"torch.Tensor.bmm",
169+
"torch.nn.GELU",
170+
"torch.broadcast_shapes",
171+
"torch.Tensor.scatter_reduce",
172+
"torch.scatter_reduce",
173+
158174

159175
# yuyan
160176

tests/test_Tensor_bmm.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,39 @@ def _test_case_4():
6464
"""
6565
)
6666
obj.run(pytorch_code, ["result"])
67+
68+
69+
def test_case_5():
70+
pytorch_code = textwrap.dedent(
71+
"""
72+
import torch
73+
a = torch.tensor([[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.], [13., 14., 15., 16.]],
74+
[[17., 18., 19., 20.], [21., 22., 23., 24.], [25., 26., 27., 28.], [29., 30., 31., 32.]],
75+
[[33., 34., 35., 36.], [37., 38., 39., 40.], [41., 42., 43., 44.], [45., 46., 47., 48.]]
76+
])
77+
b = torch.tensor([[[4., 5., 6.], [2., 3., 4.], [3., 3., 3.], [2., 2., 2.]],
78+
[[8., 10., 11.], [5., 6., 8.], [4., 4., 4.], [1., 1., 1.]],
79+
[[12., 13., 15.], [9., 10., 11.], [6., 6., 6.], [3., 3., 3.]]
80+
])
81+
result = a.bmm(b)
82+
"""
83+
)
84+
obj.run(pytorch_code, ["result"])
85+
86+
87+
def test_case_6():
88+
pytorch_code = textwrap.dedent(
89+
"""
90+
import torch
91+
a = torch.tensor([[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.], [13., 14., 15., 16.]],
92+
[[17., 18., 19., 20.], [21., 22., 23., 24.], [25., 26., 27., 28.], [29., 30., 31., 32.]],
93+
[[33., 34., 35., 36.], [37., 38., 39., 40.], [41., 42., 43., 44.], [45., 46., 47., 48.]]
94+
])
95+
b = torch.tensor([[[4., 5., 6.], [2., 3., 4.], [3., 3., 3.], [2., 2., 2.]],
96+
[[8., 10., 11.], [5., 6., 8.], [4., 4., 4.], [1., 1., 1.]],
97+
[[12., 13., 15.], [9., 10., 11.], [6., 6., 6.], [3., 3., 3.]]
98+
])
99+
result = a.bmm(mat2=b)
100+
"""
101+
)
102+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_ravel.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,14 @@ def test_case_3():
5050
"""
5151
)
5252
obj.run(pytorch_code, ["result"])
53+
54+
55+
def test_case_4():
56+
pytorch_code = textwrap.dedent(
57+
"""
58+
import torch
59+
a = torch.tensor([[4., 9., 10., 2.], [23., 12., 17., 18.]])
60+
result = a.ravel()
61+
"""
62+
)
63+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_scatter_add.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,55 @@ def test_case_5():
7878
"""
7979
)
8080
obj.run(pytorch_code, ["result"])
81+
82+
83+
def test_case_6():
84+
pytorch_code = textwrap.dedent(
85+
"""
86+
import torch
87+
src = torch.ones((3, 5))
88+
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2], [1, 2, 1, 2, 0]])
89+
input = torch.zeros(3, 5, dtype=src.dtype)
90+
result = input.scatter_add(0, index, src)
91+
"""
92+
)
93+
obj.run(pytorch_code, ["result"])
94+
95+
96+
def test_case_7():
97+
pytorch_code = textwrap.dedent(
98+
"""
99+
import torch
100+
src = torch.ones((3, 5))
101+
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2], [1, 2, 1, 2, 0]])
102+
input = torch.zeros(3, 5, dtype=src.dtype)
103+
result = input.scatter_add(dim=1, index=index, src=src)
104+
"""
105+
)
106+
obj.run(pytorch_code, ["result"])
107+
108+
109+
def test_case_8():
110+
pytorch_code = textwrap.dedent(
111+
"""
112+
import torch
113+
src = torch.ones((3, 5))
114+
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2], [1, 2, 1, 2, 0]])
115+
input = torch.zeros(3, 5, dtype=src.dtype)
116+
result = input.scatter_add(dim=0, index=index, src=src)
117+
"""
118+
)
119+
obj.run(pytorch_code, ["result"])
120+
121+
122+
def test_case_9():
123+
pytorch_code = textwrap.dedent(
124+
"""
125+
import torch
126+
src = torch.ones((3, 5))
127+
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2], [1, 2, 1, 2, 0]])
128+
input = torch.zeros(3, 5, dtype=src.dtype)
129+
result = input.scatter_add(1, index, src)
130+
"""
131+
)
132+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_scatter_add_.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,55 @@ def test_case_4():
6666
"""
6767
)
6868
obj.run(pytorch_code, ["result"])
69+
70+
71+
def test_case_5():
72+
pytorch_code = textwrap.dedent(
73+
"""
74+
import torch
75+
src = torch.ones((3, 5))
76+
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2], [1, 2, 1, 2, 0]])
77+
input = torch.zeros(3, 5, dtype=src.dtype)
78+
result = input.scatter_add_(0, index, src)
79+
"""
80+
)
81+
obj.run(pytorch_code, ["result"])
82+
83+
84+
def test_case_6():
85+
pytorch_code = textwrap.dedent(
86+
"""
87+
import torch
88+
src = torch.ones((3, 5))
89+
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2], [1, 2, 1, 2, 0]])
90+
input = torch.zeros(3, 5, dtype=src.dtype)
91+
result = input.scatter_add_(dim=1, index=index, src=src)
92+
"""
93+
)
94+
obj.run(pytorch_code, ["result"])
95+
96+
97+
def test_case_7():
98+
pytorch_code = textwrap.dedent(
99+
"""
100+
import torch
101+
src = torch.ones((3, 5))
102+
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2], [1, 2, 1, 2, 0]])
103+
input = torch.zeros(3, 5, dtype=src.dtype)
104+
result = input.scatter_add_(dim=0, index=index, src=src)
105+
"""
106+
)
107+
obj.run(pytorch_code, ["result"])
108+
109+
110+
def test_case_8():
111+
pytorch_code = textwrap.dedent(
112+
"""
113+
import torch
114+
src = torch.ones((3, 5))
115+
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2], [1, 2, 1, 2, 0]])
116+
input = torch.zeros(3, 5, dtype=src.dtype)
117+
result = input.scatter_add_(1, index, src)
118+
"""
119+
)
120+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_tril.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,73 @@ def test_case_3():
5252
"""
5353
)
5454
obj.run(pytorch_code, ["result"])
55+
56+
57+
def test_case_4():
58+
pytorch_code = textwrap.dedent(
59+
"""
60+
import torch
61+
a = torch.tensor([[1., 3., 8., 11., 56.],
62+
[15., 30., 7., 14., 90.],
63+
[10., 313., 78., 110., 34.],
64+
[33., 23., 18., 9., 41.]])
65+
result = a.tril()
66+
"""
67+
)
68+
obj.run(pytorch_code, ["result"])
69+
70+
71+
def test_case_5():
72+
pytorch_code = textwrap.dedent(
73+
"""
74+
import torch
75+
a = torch.tensor([[1., 3., 8., 11., 56.],
76+
[15., 30., 7., 14., 90.],
77+
[10., 313., 78., 110., 34.],
78+
[33., 23., 18., 9., 41.]])
79+
result = a.tril(2)
80+
"""
81+
)
82+
obj.run(pytorch_code, ["result"])
83+
84+
85+
def test_case_6():
86+
pytorch_code = textwrap.dedent(
87+
"""
88+
import torch
89+
a = torch.tensor([[1., 3., 8., 11., 56.],
90+
[15., 30., 7., 14., 90.],
91+
[10., 313., 78., 110., 34.],
92+
[33., 23., 18., 9., 41.]])
93+
result = a.tril(diagonal=-1)
94+
"""
95+
)
96+
obj.run(pytorch_code, ["result"])
97+
98+
99+
def test_case_7():
100+
pytorch_code = textwrap.dedent(
101+
"""
102+
import torch
103+
a = torch.tensor([[1., 3., 8., 11., 56.],
104+
[15., 30., 7., 14., 90.],
105+
[10., 313., 78., 110., 34.],
106+
[33., 23., 18., 9., 41.]])
107+
result = a.tril(-3)
108+
"""
109+
)
110+
obj.run(pytorch_code, ["result"])
111+
112+
113+
def test_case_8():
114+
pytorch_code = textwrap.dedent(
115+
"""
116+
import torch
117+
a = torch.tensor([[1., 3., 8., 11., 56.],
118+
[15., 30., 7., 14., 90.],
119+
[10., 313., 78., 110., 34.],
120+
[33., 23., 18., 9., 41.]])
121+
result = a.tril(diagonal=3)
122+
"""
123+
)
124+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_triu.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,73 @@ def test_case_3():
5252
"""
5353
)
5454
obj.run(pytorch_code, ["result"])
55+
56+
57+
def test_case_4():
58+
pytorch_code = textwrap.dedent(
59+
"""
60+
import torch
61+
a = torch.tensor([[1., 3., 8., 11., 56.],
62+
[15., 30., 7., 14., 90.],
63+
[10., 313., 78., 110., 34.],
64+
[33., 23., 18., 9., 41.]])
65+
result = a.triu()
66+
"""
67+
)
68+
obj.run(pytorch_code, ["result"])
69+
70+
71+
def test_case_5():
72+
pytorch_code = textwrap.dedent(
73+
"""
74+
import torch
75+
a = torch.tensor([[1., 3., 8., 11., 56.],
76+
[15., 30., 7., 14., 90.],
77+
[10., 313., 78., 110., 34.],
78+
[33., 23., 18., 9., 41.]])
79+
result = a.triu(1)
80+
"""
81+
)
82+
obj.run(pytorch_code, ["result"])
83+
84+
85+
def test_case_6():
86+
pytorch_code = textwrap.dedent(
87+
"""
88+
import torch
89+
a = torch.tensor([[1., 3., 8., 11., 56.],
90+
[15., 30., 7., 14., 90.],
91+
[10., 313., 78., 110., 34.],
92+
[33., 23., 18., 9., 41.]])
93+
result = a.triu(diagonal=-1)
94+
"""
95+
)
96+
obj.run(pytorch_code, ["result"])
97+
98+
99+
def test_case_7():
100+
pytorch_code = textwrap.dedent(
101+
"""
102+
import torch
103+
a = torch.tensor([[1., 3., 8., 11., 56.],
104+
[15., 30., 7., 14., 90.],
105+
[10., 313., 78., 110., 34.],
106+
[33., 23., 18., 9., 41.]])
107+
result = a.triu(-3)
108+
"""
109+
)
110+
obj.run(pytorch_code, ["result"])
111+
112+
113+
def test_case_8():
114+
pytorch_code = textwrap.dedent(
115+
"""
116+
import torch
117+
a = torch.tensor([[1., 3., 8., 11., 56.],
118+
[15., 30., 7., 14., 90.],
119+
[10., 313., 78., 110., 34.],
120+
[33., 23., 18., 9., 41.]])
121+
result = a.triu(diagonal=3)
122+
"""
123+
)
124+
obj.run(pytorch_code, ["result"])

tests/test_broadcast_shapes.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,4 @@ def test_case_6():
8686
result = torch.broadcast_shapes(*shapes)
8787
"""
8888
)
89-
obj.run(
90-
pytorch_code,
91-
["result"],
92-
unsupport=True,
93-
reason="The parameter *shapes is currently not supported.",
94-
)
89+
obj.run(pytorch_code, ["result"])

tests/test_nn_GELU.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,31 @@ def test_case_3():
6060
"""
6161
)
6262
obj.run(pytorch_code, ["result"])
63+
64+
65+
def test_case_4():
66+
pytorch_code = textwrap.dedent(
67+
"""
68+
import torch
69+
import torch.nn as nn
70+
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
71+
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
72+
model = nn.GELU(approximate ='none')
73+
result = model(x)
74+
"""
75+
)
76+
obj.run(pytorch_code, ["result"], atol=1e-6, rtol=1e-5)
77+
78+
79+
def test_case_5():
80+
pytorch_code = textwrap.dedent(
81+
"""
82+
import torch
83+
import torch.nn as nn
84+
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
85+
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
86+
model = nn.GELU('none')
87+
result = model(x)
88+
"""
89+
)
90+
obj.run(pytorch_code, ["result"], atol=1e-6, rtol=1e-5)

0 commit comments

Comments
 (0)