@@ -64,3 +64,39 @@ def _test_case_4():
6464 """
6565 )
6666 obj .run (pytorch_code , ["result" ])
67+
68+
69+ def test_case_5 ():
70+ pytorch_code = textwrap .dedent (
71+ """
72+ import torch
73+ a = torch.tensor([[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.], [13., 14., 15., 16.]],
74+ [[17., 18., 19., 20.], [21., 22., 23., 24.], [25., 26., 27., 28.], [29., 30., 31., 32.]],
75+ [[33., 34., 35., 36.], [37., 38., 39., 40.], [41., 42., 43., 44.], [45., 46., 47., 48.]]
76+ ])
77+ b = torch.tensor([[[4., 5., 6.], [2., 3., 4.], [3., 3., 3.], [2., 2., 2.]],
78+ [[8., 10., 11.], [5., 6., 8.], [4., 4., 4.], [1., 1., 1.]],
79+ [[12., 13., 15.], [9., 10., 11.], [6., 6., 6.], [3., 3., 3.]]
80+ ])
81+ result = a.bmm(b)
82+ """
83+ )
84+ obj .run (pytorch_code , ["result" ])
85+
86+
87+ def test_case_6 ():
88+ pytorch_code = textwrap .dedent (
89+ """
90+ import torch
91+ a = torch.tensor([[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.], [13., 14., 15., 16.]],
92+ [[17., 18., 19., 20.], [21., 22., 23., 24.], [25., 26., 27., 28.], [29., 30., 31., 32.]],
93+ [[33., 34., 35., 36.], [37., 38., 39., 40.], [41., 42., 43., 44.], [45., 46., 47., 48.]]
94+ ])
95+ b = torch.tensor([[[4., 5., 6.], [2., 3., 4.], [3., 3., 3.], [2., 2., 2.]],
96+ [[8., 10., 11.], [5., 6., 8.], [4., 4., 4.], [1., 1., 1.]],
97+ [[12., 13., 15.], [9., 10., 11.], [6., 6., 6.], [3., 3., 3.]]
98+ ])
99+ result = a.bmm(mat2=b)
100+ """
101+ )
102+ obj .run (pytorch_code , ["result" ])
0 commit comments