1717import numpy as np
1818
1919import paddle
20- from paddle .static import Program , program_guard
20+ from paddle .pir_utils import test_with_pir_api
2121
2222
2323class TestMultiplyApi (unittest .TestCase ):
2424 def _run_static_graph_case (self , x_data , y_data ):
25- with program_guard (Program (), Program ()):
25+ with paddle .static .program_guard (
26+ paddle .static .Program (), paddle .static .Program ()
27+ ):
2628 paddle .enable_static ()
2729 x = paddle .static .data (
2830 name = 'x' , shape = x_data .shape , dtype = x_data .dtype
@@ -53,45 +55,52 @@ def _run_dynamic_graph_case(self, x_data, y_data):
5355 res = paddle .inner (x , y )
5456 return res .numpy ()
5557
56- def test_multiply (self ):
57- np .random .seed (7 )
58-
58+ @test_with_pir_api
59+ def test_multiply_static_case1 (self ):
5960 # test static computation graph: 3-d array
6061 x_data = np .random .rand (2 , 10 , 10 ).astype (np .float64 )
6162 y_data = np .random .rand (2 , 5 , 10 ).astype (np .float64 )
6263 res = self ._run_static_graph_case (x_data , y_data )
6364 np .testing .assert_allclose (res , np .inner (x_data , y_data ), rtol = 1e-05 )
6465
66+ @test_with_pir_api
67+ def test_multiply_static_case2 (self ):
6568 # test static computation graph: 2-d array
6669 x_data = np .random .rand (200 , 5 ).astype (np .float64 )
6770 y_data = np .random .rand (50 , 5 ).astype (np .float64 )
6871 res = self ._run_static_graph_case (x_data , y_data )
6972 np .testing .assert_allclose (res , np .inner (x_data , y_data ), rtol = 1e-05 )
7073
74+ @test_with_pir_api
75+ def test_multiply_static_case3 (self ):
7176 # test static computation graph: 1-d array
7277 x_data = np .random .rand (50 ).astype (np .float64 )
7378 y_data = np .random .rand (50 ).astype (np .float64 )
7479 res = self ._run_static_graph_case (x_data , y_data )
7580 np .testing .assert_allclose (res , np .inner (x_data , y_data ), rtol = 1e-05 )
7681
82+ def test_multiply_dynamic_case1 (self ):
7783 # test dynamic computation graph: 3-d array
7884 x_data = np .random .rand (5 , 10 , 10 ).astype (np .float64 )
7985 y_data = np .random .rand (2 , 10 ).astype (np .float64 )
8086 res = self ._run_dynamic_graph_case (x_data , y_data )
8187 np .testing .assert_allclose (res , np .inner (x_data , y_data ), rtol = 1e-05 )
8288
89+ def test_multiply_dynamic_case2 (self ):
8390 # test dynamic computation graph: 2-d array
8491 x_data = np .random .rand (20 , 50 ).astype (np .float64 )
8592 y_data = np .random .rand (50 ).astype (np .float64 )
8693 res = self ._run_dynamic_graph_case (x_data , y_data )
8794 np .testing .assert_allclose (res , np .inner (x_data , y_data ), rtol = 1e-05 )
8895
96+ def test_multiply_dynamic_case3 (self ):
8997 # test dynamic computation graph: Scalar
9098 x_data = np .random .rand (20 , 10 ).astype (np .float32 )
9199 y_data = np .random .rand (1 ).astype (np .float32 ).item ()
92100 res = self ._run_dynamic_graph_case (x_data , y_data )
93101 np .testing .assert_allclose (res , np .inner (x_data , y_data ), rtol = 1e-05 )
94102
103+ def test_multiply_dynamic_case4 (self ):
95104 # test dynamic computation graph: 2-d array Complex
96105 x_data = np .random .rand (20 , 50 ).astype (
97106 np .float64
@@ -102,6 +111,7 @@ def test_multiply(self):
102111 res = self ._run_dynamic_graph_case (x_data , y_data )
103112 np .testing .assert_allclose (res , np .inner (x_data , y_data ), rtol = 1e-05 )
104113
114+ def test_multiply_dynamic_case5 (self ):
105115 # test dynamic computation graph: 3-d array Complex
106116 x_data = np .random .rand (5 , 10 , 10 ).astype (
107117 np .float64
@@ -114,41 +124,49 @@ def test_multiply(self):
114124
115125
116126class TestMultiplyError (unittest .TestCase ):
117- def test_errors (self ):
127+ def test_errors_static_case1 (self ):
118128 # test static computation graph: dtype can not be int8
119129 paddle .enable_static ()
120- with program_guard (Program (), Program ()):
130+ with paddle .static .program_guard (
131+ paddle .static .Program (), paddle .static .Program ()
132+ ):
121133 x = paddle .static .data (name = 'x' , shape = [100 ], dtype = np .int8 )
122134 y = paddle .static .data (name = 'y' , shape = [100 ], dtype = np .int8 )
123135 self .assertRaises (TypeError , paddle .inner , x , y )
124136
137+ def test_errors_static_case2 (self ):
125138 # test static computation graph: inputs must be broadcastable
126- with program_guard (Program (), Program ()):
139+ paddle .enable_static ()
140+ with paddle .static .program_guard (
141+ paddle .static .Program (), paddle .static .Program ()
142+ ):
127143 x = paddle .static .data (name = 'x' , shape = [20 , 50 ], dtype = np .float64 )
128144 y = paddle .static .data (name = 'y' , shape = [20 ], dtype = np .float64 )
129145 self .assertRaises (ValueError , paddle .inner , x , y )
130146
131- np .random .seed (7 )
132-
147+ def test_errors_dynamic_case1 (self ):
133148 # test dynamic computation graph: inputs must be broadcastable
134149 x_data = np .random .rand (20 , 5 )
135150 y_data = np .random .rand (10 , 2 )
136151 x = paddle .to_tensor (x_data )
137152 y = paddle .to_tensor (y_data )
138153 self .assertRaises (ValueError , paddle .inner , x , y )
139154
155+ def test_errors_dynamic_case2 (self ):
140156 # test dynamic computation graph: dtype must be Tensor type
141157 x_data = np .random .randn (200 ).astype (np .float64 )
142158 y_data = np .random .randn (200 ).astype (np .float64 )
143159 y = paddle .to_tensor (y_data )
144160 self .assertRaises (TypeError , paddle .inner , x_data , y )
145161
162+ def test_errors_dynamic_case3 (self ):
146163 # test dynamic computation graph: dtype must be Tensor type
147164 x_data = np .random .randn (200 ).astype (np .float64 )
148165 y_data = np .random .randn (200 ).astype (np .float64 )
149166 x = paddle .to_tensor (x_data )
150167 self .assertRaises (TypeError , paddle .inner , x , y_data )
151168
169+ def test_errors_dynamic_case4 (self ):
152170 # test dynamic computation graph: dtype must be Tensor type
153171 x_data = np .random .randn (200 ).astype (np .float32 )
154172 y_data = np .random .randn (200 ).astype (np .float32 )
0 commit comments