|
13 | 13 |
|
14 | 14 | import random |
15 | 15 | from enum import Enum |
| 16 | +from typing import TYPE_CHECKING |
16 | 17 |
|
| 18 | +from monai.config import IgniteInfo |
17 | 19 | from monai.utils import deprecated |
| 20 | +from monai.utils.module import min_version, optional_import |
18 | 21 |
|
19 | 22 | __all__ = [ |
20 | 23 | "StrEnum", |
@@ -88,6 +91,14 @@ def __repr__(self): |
88 | 91 | return self.value |
89 | 92 |
|
90 | 93 |
|
| 94 | +if TYPE_CHECKING: |
| 95 | + from ignite.engine import EventEnum |
| 96 | +else: |
| 97 | + EventEnum, _ = optional_import( |
| 98 | + "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base" |
| 99 | + ) |
| 100 | + |
| 101 | + |
91 | 102 | class NumpyPadMode(StrEnum): |
92 | 103 | """ |
93 | 104 | See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html |
@@ -692,3 +703,39 @@ class AlgoKeys(StrEnum): |
692 | 703 | ALGO = "algo_instance" |
693 | 704 | IS_TRAINED = "is_trained" |
694 | 705 | SCORE = "best_metric" |
| 706 | + |
| 707 | + |
| 708 | +class AdversarialKeys(StrEnum): |
| 709 | + REALS = "reals" |
| 710 | + REAL_LOGITS = "real_logits" |
| 711 | + FAKES = "fakes" |
| 712 | + FAKE_LOGITS = "fake_logits" |
| 713 | + RECONSTRUCTION_LOSS = "reconstruction_loss" |
| 714 | + GENERATOR_LOSS = "generator_loss" |
| 715 | + DISCRIMINATOR_LOSS = "discriminator_loss" |
| 716 | + |
| 717 | + |
| 718 | +class AdversarialIterationEvents(EventEnum): |
| 719 | + RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed" |
| 720 | + GENERATOR_FORWARD_COMPLETED = "generator_forward_completed" |
| 721 | + GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed" |
| 722 | + GENERATOR_LOSS_COMPLETED = "generator_loss_completed" |
| 723 | + GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed" |
| 724 | + GENERATOR_MODEL_COMPLETED = "generator_model_completed" |
| 725 | + DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed" |
| 726 | + DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed" |
| 727 | + DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed" |
| 728 | + DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed" |
| 729 | + DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed" |
| 730 | + |
| 731 | + |
| 732 | +class OrderingType(StrEnum): |
| 733 | + RASTER_SCAN = "raster_scan" |
| 734 | + S_CURVE = "s_curve" |
| 735 | + RANDOM = "random" |
| 736 | + |
| 737 | + |
| 738 | +class OrderingTransformations(StrEnum): |
| 739 | + ROTATE_90 = "rotate_90" |
| 740 | + TRANSPOSE = "transpose" |
| 741 | + REFLECT = "reflect" |
0 commit comments