Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 29 additions & 29 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Sequence
from enum import Enum
from types import DynamicClassAttribute
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Iterable, Iterator, TYPE_CHECKING

from torch.utils.data import DataLoader

Expand All @@ -28,7 +28,7 @@ class CallableEventWithFilter:
name: The enum-name of the current object. Only needed for internal use. Do not touch!
"""

def __init__(self, value: str, event_filter: Optional[Callable] = None, name: Optional[str] = None) -> None:
def __init__(self, value: str, event_filter: Callable | None = None, name: str | None = None) -> None:
self.filter = event_filter

if not hasattr(self, "_value_"):
Expand All @@ -50,11 +50,11 @@ def value(self) -> str:

def __call__(
self,
event_filter: Optional[Callable] = None,
every: Optional[int] = None,
once: Optional[Union[int, List]] = None,
before: Optional[int] = None,
after: Optional[int] = None,
event_filter: Callable | None = None,
every: int | None = None,
once: int | list | None = None,
before: int | None = None,
after: int | None = None,
) -> "CallableEventWithFilter":
"""
Makes the event class callable and accepts either an arbitrary callable as filter
Expand Down Expand Up @@ -138,7 +138,7 @@ def wrapper(engine: "Engine", event: int) -> bool:
return wrapper

@staticmethod
def once_event_filter(once: List) -> Callable:
def once_event_filter(once: list) -> Callable:
"""A wrapper for once event filter."""

def wrapper(engine: "Engine", event: int) -> bool:
Expand All @@ -149,9 +149,9 @@ def wrapper(engine: "Engine", event: int) -> bool:
return wrapper

@staticmethod
def before_and_after_event_filter(before: Optional[int] = None, after: Optional[int] = None) -> Callable:
def before_and_after_event_filter(before: int | None = None, after: int | None = None) -> Callable:
"""A wrapper for before and after event filter."""
before_: Union[int, float] = float("inf") if before is None else before
before_: int | float = float("inf") if before is None else before
after_: int = 0 if after is None else after

def wrapper(engine: "Engine", event: int) -> bool:
Expand All @@ -163,10 +163,10 @@ def wrapper(engine: "Engine", event: int) -> bool:

@staticmethod
def every_before_and_after_event_filter(
every: int, before: Optional[int] = None, after: Optional[int] = None
every: int, before: int | None = None, after: int | None = None
) -> Callable:
"""A wrapper which triggers for every `every` iterations after `after` and before `before`."""
before_: Union[int, float] = float("inf") if before is None else before
before_: int | float = float("inf") if before is None else before
after_: int = 0 if after is None else after

def wrapper(engine: "Engine", event: int) -> bool:
Expand Down Expand Up @@ -428,23 +428,23 @@ def call_on_events(engine):
"""

def __init__(self) -> None:
self._events: List[Union[Events, CallableEventWithFilter]] = []
self._events: list[Events | CallableEventWithFilter] = []

def _append(self, event: Union[Events, CallableEventWithFilter]) -> None:
def _append(self, event: Events | CallableEventWithFilter) -> None:
if not isinstance(event, (Events, CallableEventWithFilter)):
raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}")
self._events.append(event)

def __getitem__(self, item: int) -> Union[Events, CallableEventWithFilter]:
def __getitem__(self, item: int) -> Events | CallableEventWithFilter:
return self._events[item]

def __iter__(self) -> Iterator[Union[Events, CallableEventWithFilter]]:
def __iter__(self) -> Iterator[Events | CallableEventWithFilter]:
return iter(self._events)

def __len__(self) -> int:
return len(self._events)

def __or__(self, other: Union[Events, CallableEventWithFilter]) -> "EventsList":
def __or__(self, other: Events | CallableEventWithFilter) -> "EventsList":
self._append(event=other)
return self

Expand Down Expand Up @@ -472,7 +472,7 @@ class State:
kwargs: keyword arguments to be defined as State attributes.
"""

event_to_attr: Dict[Union[str, "Events", "CallableEventWithFilter"], str] = {
event_to_attr: dict[str | Events | CallableEventWithFilter, str] = {
Events.GET_BATCH_STARTED: "iteration",
Events.GET_BATCH_COMPLETED: "iteration",
Events.ITERATION_STARTED: "iteration",
Expand All @@ -486,15 +486,15 @@ class State:
def __init__(self, **kwargs: Any) -> None:
self.iteration = 0
self.epoch = 0
self.epoch_length: Optional[int] = None
self.max_epochs: Optional[int] = None
self.max_iters: Optional[int] = None
self.output: Optional[int] = None
self.batch: Optional[int] = None
self.metrics: Dict[str, Any] = {}
self.dataloader: Optional[Union[DataLoader, Iterable[Any]]] = None
self.seed: Optional[int] = None
self.times: Dict[str, Optional[float]] = {
self.epoch_length: int | None = None
self.max_epochs: int | None = None
self.max_iters: int | None = None
self.output: int | None = None
self.batch: int | None = None
self.metrics: dict[str, Any] = {}
self.dataloader: DataLoader | Iterable[Any] | None = None
self.seed: int | None = None
self.times: dict[str, float | None] = {
Events.EPOCH_COMPLETED.name: None,
Events.COMPLETED.name: None,
}
Expand All @@ -509,7 +509,7 @@ def _update_attrs(self) -> None:
if not hasattr(self, value):
setattr(self, value, 0)

def get_event_attrib_value(self, event_name: Union[str, Events, CallableEventWithFilter]) -> int:
def get_event_attrib_value(self, event_name: str | Events | CallableEventWithFilter) -> int:
"""Get the value of Event attribute with given `event_name`."""
if event_name not in State.event_to_attr:
raise RuntimeError(f"Unknown event name '{event_name}'")
Expand Down Expand Up @@ -553,7 +553,7 @@ def print_epoch(engine):
"""

def __init__(
self, event_name: Union[CallableEventWithFilter, Enum, EventsList, Events], handler: Callable, engine: "Engine"
self, event_name: CallableEventWithFilter | Enum | EventsList | Events, handler: Callable, engine: "Engine"
) -> None:
self.event_name = event_name
self.handler = weakref.ref(handler)
Expand Down
Loading