|
18 | 18 | import scipy.stats |
19 | 19 | from distribution.config import ATOL, DEVICES, RTOL |
20 | 20 | from parameterize import TEST_CASE_NAME, parameterize_cls, place |
21 | | -from paddle.pir_utils import test_with_pir_api |
22 | 21 |
|
23 | 22 | import paddle |
24 | 23 |
|
@@ -140,94 +139,91 @@ def setUp(self): |
140 | 139 | self._paddle_diric = paddle.distribution.Dirichlet(conc) |
141 | 140 | self.feeds = {'conc': self.concentration} |
142 | 141 |
|
143 | | - @test_with_pir_api |
144 | 142 | def test_mean(self): |
145 | | - with paddle.static.program_guard(self.program): |
146 | | - [out] = self.executor.run( |
147 | | - self.program, |
148 | | - feed=self.feeds, |
149 | | - fetch_list=[self._paddle_diric.mean], |
150 | | - ) |
151 | | - np.testing.assert_allclose( |
152 | | - out, |
153 | | - scipy.stats.dirichlet.mean(self.concentration), |
154 | | - rtol=RTOL.get(str(self.concentration.dtype)), |
155 | | - atol=ATOL.get(str(self.concentration.dtype)), |
156 | | - ) |
| 143 | + with paddle.pir_utils.IrGuard(): |
| 144 | + with paddle.static.program_guard(self.program): |
| 145 | + [out] = self.executor.run( |
| 146 | + self.program, |
| 147 | + feed=self.feeds, |
| 148 | + fetch_list=[self._paddle_diric.mean], |
| 149 | + ) |
| 150 | + np.testing.assert_allclose( |
| 151 | + out, |
| 152 | + scipy.stats.dirichlet.mean(self.concentration), |
| 153 | + rtol=RTOL.get(str(self.concentration.dtype)), |
| 154 | + atol=ATOL.get(str(self.concentration.dtype)), |
| 155 | + ) |
157 | 156 |
|
158 | | - @test_with_pir_api |
159 | 157 | def test_variance(self): |
160 | | - with paddle.static.program_guard(self.program): |
161 | | - [out] = self.executor.run( |
162 | | - self.program, |
163 | | - feed=self.feeds, |
164 | | - fetch_list=[self._paddle_diric.variance], |
165 | | - ) |
166 | | - np.testing.assert_allclose( |
167 | | - out, |
168 | | - scipy.stats.dirichlet.var(self.concentration), |
169 | | - rtol=RTOL.get(str(self.concentration.dtype)), |
170 | | - atol=ATOL.get(str(self.concentration.dtype)), |
171 | | - ) |
| 158 | + with paddle.pir_utils.IrGuard(): |
| 159 | + with paddle.static.program_guard(self.program): |
| 160 | + [out] = self.executor.run( |
| 161 | + self.program, |
| 162 | + feed=self.feeds, |
| 163 | + fetch_list=[self._paddle_diric.variance], |
| 164 | + ) |
| 165 | + np.testing.assert_allclose( |
| 166 | + out, |
| 167 | + scipy.stats.dirichlet.var(self.concentration), |
| 168 | + rtol=RTOL.get(str(self.concentration.dtype)), |
| 169 | + atol=ATOL.get(str(self.concentration.dtype)), |
| 170 | + ) |
172 | 171 |
|
173 | | - @test_with_pir_api |
174 | 172 | def test_prob(self): |
175 | | - with paddle.static.program_guard(self.program): |
176 | | - random_number = np.random.rand(*self.concentration.shape) |
177 | | - random_number = random_number / random_number.sum() |
178 | | - feeds = dict(self.feeds, value=random_number) |
179 | | - value = paddle.static.data( |
180 | | - 'value', random_number.shape, random_number.dtype |
181 | | - ) |
182 | | - out = self._paddle_diric.prob(value) |
183 | | - [out] = self.executor.run( |
184 | | - self.program, feed=feeds, fetch_list=[out] |
185 | | - ) |
186 | | - np.testing.assert_allclose( |
187 | | - out, |
188 | | - scipy.stats.dirichlet.pdf(random_number, self.concentration), |
189 | | - rtol=RTOL.get(str(self.concentration.dtype)), |
190 | | - atol=ATOL.get(str(self.concentration.dtype)), |
191 | | - ) |
| 173 | + with paddle.pir_utils.IrGuard(): |
| 174 | + with paddle.static.program_guard(self.program): |
| 175 | + random_number = np.random.rand(*self.concentration.shape) |
| 176 | + random_number = random_number / random_number.sum() |
| 177 | + feeds = dict(self.feeds, value=random_number) |
| 178 | + value = paddle.static.data( |
| 179 | + 'value', random_number.shape, random_number.dtype |
| 180 | + ) |
| 181 | + out = self._paddle_diric.prob(value) |
| 182 | + [out] = self.executor.run( |
| 183 | + self.program, feed=feeds, fetch_list=[out] |
| 184 | + ) |
| 185 | + np.testing.assert_allclose( |
| 186 | + out, |
| 187 | + scipy.stats.dirichlet.pdf( |
| 188 | + random_number, self.concentration |
| 189 | + ), |
| 190 | + rtol=RTOL.get(str(self.concentration.dtype)), |
| 191 | + atol=ATOL.get(str(self.concentration.dtype)), |
| 192 | + ) |
192 | 193 |
|
193 | | - @test_with_pir_api |
194 | 194 | def test_log_prob(self): |
195 | | - with paddle.static.program_guard(self.program): |
196 | | - random_number = np.random.rand(*self.concentration.shape) |
197 | | - random_number = random_number / random_number.sum() |
198 | | - feeds = dict(self.feeds, value=random_number) |
199 | | - value = paddle.static.data( |
200 | | - 'value', random_number.shape, random_number.dtype |
201 | | - ) |
202 | | - out = self._paddle_diric.log_prob(value) |
203 | | - [out] = self.executor.run( |
204 | | - self.program, feed=feeds, fetch_list=[out] |
205 | | - ) |
206 | | - np.testing.assert_allclose( |
207 | | - out, |
208 | | - scipy.stats.dirichlet.logpdf(random_number, self.concentration), |
209 | | - rtol=RTOL.get(str(self.concentration.dtype)), |
210 | | - atol=ATOL.get(str(self.concentration.dtype)), |
211 | | - ) |
| 195 | + with paddle.pir_utils.IrGuard(): |
| 196 | + with paddle.static.program_guard(self.program): |
| 197 | + random_number = np.random.rand(*self.concentration.shape) |
| 198 | + random_number = random_number / random_number.sum() |
| 199 | + feeds = dict(self.feeds, value=random_number) |
| 200 | + value = paddle.static.data( |
| 201 | + 'value', random_number.shape, random_number.dtype |
| 202 | + ) |
| 203 | + out = self._paddle_diric.log_prob(value) |
| 204 | + [out] = self.executor.run( |
| 205 | + self.program, feed=feeds, fetch_list=[out] |
| 206 | + ) |
| 207 | + np.testing.assert_allclose( |
| 208 | + out, |
| 209 | + scipy.stats.dirichlet.logpdf( |
| 210 | + random_number, self.concentration |
| 211 | + ), |
| 212 | + rtol=RTOL.get(str(self.concentration.dtype)), |
| 213 | + atol=ATOL.get(str(self.concentration.dtype)), |
| 214 | + ) |
212 | 215 |
|
213 | | - @test_with_pir_api |
214 | 216 | def test_entropy(self): |
215 | | - with paddle.static.program_guard(self.program): |
216 | | - [out] = self.executor.run( |
217 | | - self.program, |
218 | | - feed=self.feeds, |
219 | | - fetch_list=[self._paddle_diric.entropy()], |
220 | | - ) |
221 | | - np.testing.assert_allclose( |
222 | | - out, |
223 | | - scipy.stats.dirichlet.entropy(self.concentration), |
224 | | - rtol=RTOL.get(str(self.concentration.dtype)), |
225 | | - atol=ATOL.get(str(self.concentration.dtype)), |
226 | | - ) |
227 | | - |
228 | | - def test_all(self): |
229 | | - self._test_mean(self.program) |
230 | | - self._test_variance(self.program) |
231 | | - self._test_prob(self.program) |
232 | | - self._test_log_prob(self.program) |
233 | | - self._test_entropy(self.program) |
| 217 | + with paddle.pir_utils.IrGuard(): |
| 218 | + with paddle.static.program_guard(self.program): |
| 219 | + [out] = self.executor.run( |
| 220 | + self.program, |
| 221 | + feed=self.feeds, |
| 222 | + fetch_list=[self._paddle_diric.entropy()], |
| 223 | + ) |
| 224 | + np.testing.assert_allclose( |
| 225 | + out, |
| 226 | + scipy.stats.dirichlet.entropy(self.concentration), |
| 227 | + rtol=RTOL.get(str(self.concentration.dtype)), |
| 228 | + atol=ATOL.get(str(self.concentration.dtype)), |
| 229 | + ) |
0 commit comments