Skip to content

Commit a9b7e0a

Browse files
Copilotvfdev-5
andauthored
Fix error message typo and remove redundant type casts (#3503)
Addresses code review findings and Pyrefly warnings in metrics module. ## Error Message Correction `MeanPairwiseDistance.compute()` raised `NotComputableError` with incorrect class name reference: ```python # Before raise NotComputableError("MeanAbsoluteError must have at least one example...") # After raise NotComputableError("MeanPairwiseDistance must have at least one example...") ``` Updated corresponding test expectation in `test_mean_pairwise_distance.py::test_zero_sample`. ## Redundant Cast Removal Removed unnecessary `cast()` calls in `Loss.update()` where function signature already provides type narrowing: ```python def update(self, output: tuple[Tensor, Tensor] | tuple[Tensor, Tensor, dict]) -> None: if len(output) == 2: y_pred, y = output # cast() removed - type already constrained else: y_pred, y, kwargs = output # cast() removed - type already constrained ``` Removed unused `cast` import from `typing`. <!-- START COPILOT CODING AGENT TIPS --> --- 💬 We'd love your input! Share your thoughts on Copilot coding agent in our [2 minute survey](https://gh.io/copilot-coding-agent-survey). --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: vfdev-5 <2459423+vfdev-5@users.noreply.github.com>
1 parent ae54b27 commit a9b7e0a

3 files changed

Lines changed: 8 additions & 8 deletions

File tree

ignite/metrics/loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, cast
1+
from typing import Callable
22

33
import torch
44

@@ -92,10 +92,10 @@ def reset(self) -> None:
9292
@reinit__is_reduced
9393
def update(self, output: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, dict]) -> None:
9494
if len(output) == 2:
95-
y_pred, y = cast(tuple[torch.Tensor, torch.Tensor], output)
95+
y_pred, y = output
9696
kwargs: dict = {}
9797
else:
98-
y_pred, y, kwargs = cast(tuple[torch.Tensor, torch.Tensor, dict], output)
98+
y_pred, y, kwargs = output
9999
average_loss = self._loss_fn(y_pred, y, **kwargs).detach()
100100

101101
if len(average_loss.shape) != 0:

ignite/metrics/mean_pairwise_distance.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Sequence, Union
1+
from typing import Callable, Sequence
22

33
import torch
44
from torch.nn.functional import pairwise_distance
@@ -72,7 +72,7 @@ def __init__(
7272
p: int = 2,
7373
eps: float = 1e-6,
7474
output_transform: Callable = lambda x: x,
75-
device: Union[str, torch.device] = torch.device("cpu"),
75+
device: str | torch.device = torch.device("cpu"),
7676
skip_unrolling: bool = False,
7777
) -> None:
7878
super(MeanPairwiseDistance, self).__init__(output_transform, device=device, skip_unrolling=False)
@@ -92,7 +92,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
9292
self._num_examples += y.shape[0]
9393

9494
@sync_all_reduce("_sum_of_distances", "_num_examples")
95-
def compute(self) -> Union[float, torch.Tensor]:
95+
def compute(self) -> float | torch.Tensor:
9696
if self._num_examples == 0:
97-
raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
97+
raise NotComputableError("MeanPairwiseDistance must have at least one example before it can be computed.")
9898
return self._sum_of_distances.item() / self._num_examples

tests/ignite/metrics/test_mean_pairwise_distance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
def test_zero_sample():
1313
mpd = MeanPairwiseDistance()
1414
with pytest.raises(
15-
NotComputableError, match=r"MeanAbsoluteError must have at least one example before it can be computed"
15+
NotComputableError, match=r"MeanPairwiseDistance must have at least one example before it can be computed"
1616
):
1717
mpd.compute()
1818

0 commit comments

Comments
 (0)