Skip to content

Commit 7aaeab0

Browse files
committed
Adds new enums
1 parent 8a70678 commit 7aaeab0

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

monai/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
1818
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
1919
from .enums import (
20+
AdversarialIterationEvents,
21+
AdversarialKeys,
2022
AlgoKeys,
2123
Average,
2224
BlendMode,
@@ -46,6 +48,8 @@
4648
MetricReduction,
4749
NdimageMode,
4850
NumpyPadMode,
51+
OrderingTransformations,
52+
OrderingType,
4953
PatchKeys,
5054
PostFix,
5155
ProbMapKeys,

monai/utils/enums.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313

1414
import random
1515
from enum import Enum
16+
from typing import TYPE_CHECKING
1617

18+
from monai.config import IgniteInfo
1719
from monai.utils import deprecated
20+
from monai.utils.module import min_version, optional_import
1821

1922
__all__ = [
2023
"StrEnum",
@@ -88,6 +91,14 @@ def __repr__(self):
8891
return self.value
8992

9093

94+
if TYPE_CHECKING:
95+
from ignite.engine import EventEnum
96+
else:
97+
EventEnum, _ = optional_import(
98+
"ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
99+
)
100+
101+
91102
class NumpyPadMode(StrEnum):
92103
"""
93104
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
@@ -692,3 +703,39 @@ class AlgoKeys(StrEnum):
692703
ALGO = "algo_instance"
693704
IS_TRAINED = "is_trained"
694705
SCORE = "best_metric"
706+
707+
708+
class AdversarialKeys(StrEnum):
709+
REALS = "reals"
710+
REAL_LOGITS = "real_logits"
711+
FAKES = "fakes"
712+
FAKE_LOGITS = "fake_logits"
713+
RECONSTRUCTION_LOSS = "reconstruction_loss"
714+
GENERATOR_LOSS = "generator_loss"
715+
DISCRIMINATOR_LOSS = "discriminator_loss"
716+
717+
718+
class AdversarialIterationEvents(EventEnum):
719+
RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed"
720+
GENERATOR_FORWARD_COMPLETED = "generator_forward_completed"
721+
GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed"
722+
GENERATOR_LOSS_COMPLETED = "generator_loss_completed"
723+
GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed"
724+
GENERATOR_MODEL_COMPLETED = "generator_model_completed"
725+
DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed"
726+
DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed"
727+
DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed"
728+
DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed"
729+
DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed"
730+
731+
732+
class OrderingType(StrEnum):
733+
RASTER_SCAN = "raster_scan"
734+
S_CURVE = "s_curve"
735+
RANDOM = "random"
736+
737+
738+
class OrderingTransformations(StrEnum):
739+
ROTATE_90 = "rotate_90"
740+
TRANSPOSE = "transpose"
741+
REFLECT = "reflect"

0 commit comments

Comments
 (0)