2 changes: 1 addition & 1 deletion .gitignore
@@ -45,7 +45,7 @@ coverage.xml
 .venv/
 .python-version
 conda_build/
-
+venv/
 .neptune/
 pytest.ini
50 changes: 25 additions & 25 deletions ignite/handlers/base_logger.py
@@ -5,7 +5,7 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable

 import torch
 import torch.nn as nn
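
The shrunken import reflects PEP 604 union syntax (`X | Y` for `Union[X, Y]`, `X | None` for `Optional[X]`) and PEP 585 built-in generics (`list[str]`, `dict[Any, Any]`, `tuple[str, ...]`), both of which need Python 3.10+ at runtime. A minimal sketch of the before/after spellings (illustrative, not code from this PR):

# Old spelling, typing-module generics:
#   def pick(xs: Optional[List[str]]) -> Union[str, None]: ...
# New spelling, no typing imports needed on Python 3.10+:
def pick(xs: list[str] | None) -> str | None:
    return xs[0] if xs else None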
@@ -19,7 +19,7 @@
     """Base handler for defining various useful handlers."""

     @abstractmethod
-    def __call__(self, engine: Engine, logger: Any, event_name: Union[str, Events]) -> None:
+    def __call__(self, engine: Engine, logger: Any, event_name: str | Events) -> None:
         pass

@@ -31,8 +31,8 @@
     def __init__(
         self,
         model: nn.Module,
-        tag: Optional[str] = None,
-        whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
+        tag: str | None = None,
+        whitelist: list[str] | Callable[[str, nn.Parameter], bool] | None = None,
     ):
         if not isinstance(model, torch.nn.Module):
             raise TypeError(f"Argument model should be of type torch.nn.Module, but given {type(model)}")
@@ -61,14 +61,14 @@
     Base handler for logging optimizer parameters
     """

-    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
+    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: str | None = None):
         if not (
             isinstance(optimizer, Optimizer)
-            or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
+            or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, collections.Sequence))
         ):
             raise TypeError(
                 "Argument optimizer should be torch.optim.Optimizer or has attribute 'param_groups' as list/tuple, "
-                f"but given {type(optimizer)}"
+                (f"but given {type(optimizer)}")
             )

         self.optimizer = optimizer

Check failure on line 70 in ignite/handlers/base_logger.py — GitHub Actions / pyrefly (3.10, pytorch) and pyrefly (3.13, pytorch):
Pyrefly not-callable: Expected a callable, got `Literal['Argument optimizer should be torch.optim.Optimizer or has attribute \'param_groups\' as list/tuple, ']`
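
The pyrefly not-callable failure is real: wrapping the second literal in parentheses turns implicit string concatenation into a call expression, so the first literal is "called" with the f-string as its argument. A minimal reproduction (hypothetical snippet, not from the repository):

# Broken: a string literal followed by a parenthesized expression parses as a
# call -> TypeError: 'str' object is not callable (pyrefly's not-callable).
#   msg = "given, " (f"{type(1)}")
# Working: adjacent literals concatenate implicitly, no parentheses needed.
msg = "given, " f"{type(1)}"

Separately, `collections.Sequence` was removed in Python 3.10; assuming the module imports `collections`, the surviving runtime spelling is `collections.abc.Sequence`.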
@@ -84,10 +84,10 @@
     def __init__(
         self,
         tag: str,
-        metric_names: Optional[Union[str, List[str]]] = None,
-        output_transform: Optional[Callable] = None,
-        global_step_transform: Optional[Callable[[Engine, Union[str, Events]], int]] = None,
-        state_attributes: Optional[List[str]] = None,
+        metric_names: str | list[str] | None = None,
+        output_transform: Callable | None = None,
+        global_step_transform: Callable[[Engine, str | Events], int] | None = None,
+        state_attributes: list[str] | None = None,
     ):
         if metric_names is not None:
             if not (isinstance(metric_names, list) or (isinstance(metric_names, str) and metric_names == "all")):
@@ -106,7 +106,7 @@

         if global_step_transform is None:

-            def global_step_transform(engine: Engine, event_name: Union[str, Events]) -> int:
+            def global_step_transform(engine: Engine, event_name: str | Events) -> int:
                 return engine.state.get_event_attrib_value(event_name)

         self.tag = tag
@@ -116,8 +116,8 @@
         self.state_attributes = state_attributes

     def _setup_output_metrics_state_attrs(
-        self, engine: Engine, log_text: Optional[bool] = False, key_tuple: Optional[bool] = True
-    ) -> Dict[Any, Any]:
+        self, engine: Engine, log_text: bool | None = False, key_tuple: bool | None = True
+    ) -> dict[Any, Any]:
         """Helper method to setup metrics and state attributes to log"""
         metrics_state_attrs = OrderedDict()
         if self.metric_names is not None:
@@ -144,11 +144,11 @@
         if self.state_attributes is not None:
             metrics_state_attrs.update({name: getattr(engine.state, name, None) for name in self.state_attributes})

-        metrics_state_attrs_dict: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict()
+        metrics_state_attrs_dict: dict[Any, str | float | numbers.Number] = OrderedDict()

-        def key_tuple_fn(parent_key: Union[str, Tuple[str, ...]], *args: str) -> Tuple[str, ...]:
+        def key_tuple_fn(parent_key: str | tuple[str, ...] | None, *args: str) -> tuple[str, ...]:
             if parent_key is None or isinstance(parent_key, str):
                 return (parent_key,) + args
             return parent_key + args

         def key_str_fn(parent_key: str, *args: str) -> str:

Check failure on line 151 in ignite/handlers/base_logger.py — GitHub Actions / pyrefly (3.10, pytorch) and pyrefly (3.13, pytorch):
Pyrefly bad-return: Returned type `tuple[str | None, *tuple[str, ...]]` is not assignable to declared return type `tuple[str, ...]`
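
The bad-return failure follows from the widened annotation: with `parent_key: str | tuple[str, ...] | None`, the `(parent_key,) + args` branch can produce a tuple whose first element is `None`, which is not assignable to the declared `tuple[str, ...]`. One way to make the signature honest is to widen the return type instead; a hedged sketch (illustrative, not the PR's fix):

def key_tuple_fn(parent_key: str | tuple[str, ...] | None, *args: str) -> tuple[str | None, ...]:
    # None or a plain string becomes the first tuple element, as before.
    if parent_key is None or isinstance(parent_key, str):
        return (parent_key,) + args
    return parent_key + args

This keeps the runtime behavior identical while satisfying the checker.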
@@ -158,8 +158,8 @@
         key_fn = key_tuple_fn if key_tuple else key_str_fn

         def handle_value_fn(
-            value: Union[str, int, float, numbers.Number, torch.Tensor]
-        ) -> Union[None, str, float, numbers.Number]:
+            value: str | int | float | numbers.Number | torch.Tensor
+        ) -> None | str | float | numbers.Number:
             if isinstance(value, numbers.Number):
                 return value
             elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
@@ -179,8 +179,8 @@
             in_dict: collections.Mapping,
             key_fn: Callable,
             value_fn: Callable,
-            parent_key: Optional[Union[str, Tuple[str, ...]]] = None,
-        ) -> Dict:
+            parent_key: str | tuple[str, ...] | None = None,
+        ) -> dict:
             items = {}
             for key, value in in_dict.items():
                 new_key = key_fn(parent_key, key)
@@ -212,9 +212,9 @@
     def __init__(
         self,
         model: nn.Module,
-        reduction: Callable[[torch.Tensor], Union[float, torch.Tensor]] = torch.norm,
-        tag: Optional[str] = None,
-        whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
+        reduction: Callable[[torch.Tensor], float | torch.Tensor] = torch.norm,
+        tag: str | None = None,
+        whitelist: list[str] | Callable[[str, nn.Parameter], bool] | None = None,
     ):
         super(BaseWeightsScalarHandler, self).__init__(model, tag=tag, whitelist=whitelist)
@@ -242,7 +242,7 @@
         self,
         engine: Engine,
         log_handler: Callable,
-        event_name: Union[str, Events, CallableEventWithFilter, EventsList],
+        event_name: str | Events | CallableEventWithFilter | EventsList,
         *args: Any,
         **kwargs: Any,
     ) -> RemovableEventHandle:
@@ -326,4 +326,4 @@
         self.close()

     def close(self) -> None:
-        pass
+        pass