|
24 | 24 | import paddle |
25 | 25 | import paddle.fluid as fluid |
26 | 26 | import paddle.nn.functional as F |
27 | | -from paddle.fluid.framework import _test_eager_guard |
28 | 27 | from paddle.incubate.autograd.utils import as_tensors |
29 | 28 |
|
30 | 29 |
|
@@ -201,14 +200,6 @@ def func_vjp_aliased_input(self): |
201 | 200 | self.check_results(ref_result, aliased_result) |
202 | 201 |
|
203 | 202 | def test_all_cases(self): |
204 | | - with _test_eager_guard(): |
205 | | - self.func_vjp_i1o1() |
206 | | - self.func_vjp_i2o1() |
207 | | - self.func_vjp_i2o2() |
208 | | - self.func_vjp_i2o2_omitting_v() |
209 | | - self.func_vjp_nested() |
210 | | - self.func_vjp_aliased_input() |
211 | | - |
212 | 203 | self.func_vjp_i1o1() |
213 | 204 | self.func_vjp_i2o1() |
214 | 205 | self.func_vjp_i2o2() |
@@ -237,17 +228,12 @@ def test_input_single_tensor(self): |
237 | 228 | ), |
238 | 229 | ) |
239 | 230 | class TestVJPException(unittest.TestCase): |
240 | | - def func_vjp(self): |
| 231 | + def test_vjp(self): |
241 | 232 | with self.assertRaises(self.expected_exception): |
242 | 233 | paddle.incubate.autograd.vjp( |
243 | 234 | self.fun, paddle.to_tensor(self.xs), paddle.to_tensor(self.v) |
244 | 235 | ) |
245 | 236 |
|
246 | | - def test_all_cases(self): |
247 | | - with _test_eager_guard(): |
248 | | - self.func_vjp() |
249 | | - self.func_vjp() |
250 | | - |
251 | 237 |
|
252 | 238 | def jac(grad_fn, f, inputs): |
253 | 239 | assert grad_fn in [ |
@@ -324,11 +310,6 @@ def func_jvp_i2o2_omitting_v(self): |
324 | 310 | self.check_results(results_omitting_v, results_with_v) |
325 | 311 |
|
326 | 312 | def test_all_cases(self): |
327 | | - with _test_eager_guard(): |
328 | | - self.func_jvp_i1o1() |
329 | | - self.func_jvp_i2o1() |
330 | | - self.func_jvp_i2o2() |
331 | | - self.func_jvp_i2o2_omitting_v() |
332 | 313 | self.func_jvp_i1o1() |
333 | 314 | self.func_jvp_i2o1() |
334 | 315 | self.func_jvp_i2o2() |
@@ -372,7 +353,7 @@ def setUp(self): |
372 | 353 | .get("atol") |
373 | 354 | ) |
374 | 355 |
|
375 | | - def func_jacobian(self): |
| 356 | + def test_jacobian(self): |
376 | 357 | xs = ( |
377 | 358 | [paddle.to_tensor(x) for x in self.xs] |
378 | 359 | if isinstance(self.xs, typing.Sequence) |
@@ -409,11 +390,6 @@ def _get_expected(self): |
409 | 390 | ) |
410 | 391 | return utils._np_concat_matrix_sequence(jac, utils.MatrixFormat.NM) |
411 | 392 |
|
412 | | - def test_all_cases(self): |
413 | | - with _test_eager_guard(): |
414 | | - self.func_jacobian() |
415 | | - self.func_jacobian() |
416 | | - |
417 | 393 |
|
418 | 394 | @utils.place(config.DEVICES) |
419 | 395 | @utils.parameterize( |
@@ -451,7 +427,7 @@ def setUp(self): |
451 | 427 | .get("atol") |
452 | 428 | ) |
453 | 429 |
|
454 | | - def func_jacobian(self): |
| 430 | + def test_jacobian(self): |
455 | 431 | xs = ( |
456 | 432 | [paddle.to_tensor(x) for x in self.xs] |
457 | 433 | if isinstance(self.xs, typing.Sequence) |
@@ -505,11 +481,6 @@ def _get_expected(self): |
505 | 481 | jac, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM |
506 | 482 | ) |
507 | 483 |
|
508 | | - def test_all_cases(self): |
509 | | - with _test_eager_guard(): |
510 | | - self.func_jacobian() |
511 | | - self.func_jacobian() |
512 | | - |
513 | 484 |
|
514 | 485 | class TestHessianNoBatch(unittest.TestCase): |
515 | 486 | @classmethod |
@@ -607,13 +578,6 @@ def func(x): |
607 | 578 | paddle.incubate.autograd.Hessian(func, paddle.ones([3])) |
608 | 579 |
|
609 | 580 | def test_all_cases(self): |
610 | | - with _test_eager_guard(): |
611 | | - self.setUpClass() |
612 | | - self.func_single_input() |
613 | | - self.func_multi_input() |
614 | | - self.func_allow_unused_true() |
615 | | - self.func_create_graph_true() |
616 | | - self.func_out_not_single() |
617 | 581 | self.setUpClass() |
618 | 582 | self.func_single_input() |
619 | 583 | self.func_multi_input() |
@@ -744,13 +708,6 @@ def func(x): |
744 | 708 | ) |
745 | 709 |
|
746 | 710 | def test_all_cases(self): |
747 | | - with _test_eager_guard(): |
748 | | - self.setUpClass() |
749 | | - self.func_single_input() |
750 | | - self.func_multi_input() |
751 | | - self.func_allow_unused() |
752 | | - self.func_stop_gradient() |
753 | | - self.func_out_not_single() |
754 | 711 | self.setUpClass() |
755 | 712 | self.func_single_input() |
756 | 713 | self.func_multi_input() |
|
0 commit comments