diff --git a/ignite/handlers/base_logger.py b/ignite/handlers/base_logger.py
index dbb2997935e2..026c15072f67 100644
--- a/ignite/handlers/base_logger.py
+++ b/ignite/handlers/base_logger.py
@@ -5,7 +5,7 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Sequence
 
 import torch
 import torch.nn as nn
@@ -19,7 +19,7 @@ class BaseHandler(metaclass=ABCMeta):
     """Base handler for defining various useful handlers."""
 
     @abstractmethod
-    def __call__(self, engine: Engine, logger: Any, event_name: Union[str, Events]) -> None:
+    def __call__(self, engine: Engine, logger: Any, event_name: str | Events) -> None:
         pass
 
 
@@ -31,8 +31,8 @@ class BaseWeightsHandler(BaseHandler):
     def __init__(
         self,
         model: nn.Module,
-        tag: Optional[str] = None,
-        whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
+        tag: str | None = None,
+        whitelist: list[str] | Callable[[str, nn.Parameter], bool] | None = None,
     ):
         if not isinstance(model, torch.nn.Module):
             raise TypeError(f"Argument model should be of type torch.nn.Module, but given {type(model)}")
@@ -61,7 +61,7 @@ class BaseOptimizerParamsHandler(BaseHandler):
     Base handler for logging optimizer parameters
     """
 
-    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
+    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: str | None = None):
         if not (
             isinstance(optimizer, Optimizer)
             or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
@@ -84,10 +84,10 @@ class BaseOutputHandler(BaseHandler):
     def __init__(
         self,
         tag: str,
-        metric_names: Optional[Union[str, List[str]]] = None,
-        output_transform: Optional[Callable] = None,
-        global_step_transform: Optional[Callable[[Engine, Union[str, Events]], int]] = None,
-        state_attributes: Optional[List[str]] = None,
+        metric_names: str | list[str] | None = None,
+        output_transform: Callable | None = None,
+        global_step_transform: Callable[[Engine, str | Events], int] | None = None,
+        state_attributes: list[str] | None = None,
     ):
         if metric_names is not None:
             if not (isinstance(metric_names, list) or (isinstance(metric_names, str) and metric_names == "all")):
@@ -106,7 +106,7 @@ def __init__(
 
         if global_step_transform is None:
 
-            def global_step_transform(engine: Engine, event_name: Union[str, Events]) -> int:
+            def global_step_transform(engine: Engine, event_name: str | Events) -> int:
                 return engine.state.get_event_attrib_value(event_name)
 
         self.tag = tag
@@ -116,8 +116,8 @@ def global_step_transform(engine: Engine, event_name: Union[str, Events]) -> int
         self.state_attributes = state_attributes
 
     def _setup_output_metrics_state_attrs(
-        self, engine: Engine, log_text: Optional[bool] = False, key_tuple: Optional[bool] = True
-    ) -> Dict[Any, Any]:
+        self, engine: Engine, log_text: bool | None = False, key_tuple: bool | None = True
+    ) -> dict[Any, Any]:
         """Helper method to setup metrics and state attributes to log"""
         metrics_state_attrs = OrderedDict()
         if self.metric_names is not None:
@@ -144,10 +144,12 @@ def _setup_output_metrics_state_attrs(
         if self.state_attributes is not None:
             metrics_state_attrs.update({name: getattr(engine.state, name, None) for name in self.state_attributes})
 
-        metrics_state_attrs_dict: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict()
+        metrics_state_attrs_dict: dict[Any, str | float | numbers.Number] = OrderedDict()
 
-        def key_tuple_fn(parent_key: Union[str, Tuple[str, ...]], *args: str) -> Tuple[str, ...]:
-            if parent_key is None or isinstance(parent_key, str):
+        def key_tuple_fn(parent_key: str | tuple[str, ...] | None, *args: str) -> tuple[str, ...]:
+            if parent_key is None:
+                return args
+            if isinstance(parent_key, str):
                 return (parent_key,) + args
             return parent_key + args
 
@@ -158,8 +160,8 @@ def key_str_fn(parent_key: str, *args: str) -> str:
         key_fn = key_tuple_fn if key_tuple else key_str_fn
 
         def handle_value_fn(
-            value: Union[str, int, float, numbers.Number, torch.Tensor]
-        ) -> Union[None, str, float, numbers.Number]:
+            value: str | int | float | numbers.Number | torch.Tensor,
+        ) -> None | str | float | numbers.Number:
             if isinstance(value, numbers.Number):
                 return value
             elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
@@ -179,8 +181,8 @@ def _flatten_dict(
     in_dict: collections.Mapping,
     key_fn: Callable,
     value_fn: Callable,
-    parent_key: Optional[Union[str, Tuple[str, ...]]] = None,
-) -> Dict:
+    parent_key: str | tuple[str, ...] | None = None,
+) -> dict:
     items = {}
     for key, value in in_dict.items():
         new_key = key_fn(parent_key, key)
@@ -212,9 +214,9 @@ class BaseWeightsScalarHandler(BaseWeightsHandler):
     def __init__(
         self,
         model: nn.Module,
-        reduction: Callable[[torch.Tensor], Union[float, torch.Tensor]] = torch.norm,
-        tag: Optional[str] = None,
-        whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
+        reduction: Callable[[torch.Tensor], float | torch.Tensor] = torch.norm,
+        tag: str | None = None,
+        whitelist: list[str] | Callable[[str, nn.Parameter], bool] | None = None,
     ):
         super(BaseWeightsScalarHandler, self).__init__(model, tag=tag, whitelist=whitelist)
 
@@ -242,7 +244,7 @@ def attach(
         self,
         engine: Engine,
         log_handler: Callable,
-        event_name: Union[str, Events, CallableEventWithFilter, EventsList],
+        event_name: str | Events | CallableEventWithFilter | EventsList,
         *args: Any,
         **kwargs: Any,
     ) -> RemovableEventHandle:
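
Note on the `@@ -144,10 +144,12 @@` hunk: besides the syntax migration, it splits the old combined `parent_key is None or isinstance(parent_key, str)` branch, so a `None` parent key no longer leaks into the key tuple as a literal `None` element. A minimal standalone sketch of the fixed behavior (names copied from the diff for illustration; this is not ignite's full code):

```python
# Sketch of the fixed key_tuple_fn: _flatten_dict calls key_fn with
# parent_key=None at the top level, so None must not be prepended.
def key_tuple_fn(parent_key, *args):
    if parent_key is None:  # top level: nothing to prepend
        return args
    if isinstance(parent_key, str):  # promote a bare string to a 1-tuple
        return (parent_key,) + args
    return parent_key + args  # already a tuple of path components


assert key_tuple_fn(None, "loss") == ("loss",)  # old branch gave (None, "loss")
assert key_tuple_fn("train", "loss") == ("train", "loss")
assert key_tuple_fn(("train", "acc"), "top1") == ("train", "acc", "top1")
```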
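Compatibility note on the new spellings throughout the diff: PEP 604 unions such as `str | None` are evaluated at runtime only on Python >= 3.10, and builtin generics such as `list[str]` and `dict[Any, Any]` only on >= 3.9. A hypothetical minimal example (not part of the patch) showing how postponed evaluation keeps the same annotations parseable on older 3.x interpreters:

```python
# With postponed evaluation, annotations are stored as strings and never
# evaluated at definition time, so the PEP 604/585 syntax below is accepted
# even where the union operator on types does not exist at runtime.
from __future__ import annotations

from typing import Callable


def configure(
    tag: str | None = None,
    whitelist: list[str] | Callable[[str], bool] | None = None,
) -> None:
    ...
```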