Skip to content

Commit 52ea730

Browse files
fix: add tuple cast to avoid nd slicing deprecation warning (#3319)
Signed-off-by: Kyle Mylonakis <kyle@protopia.ai>
Co-authored-by: VijayVignesh1 <vijayvigneshp02@gmail.com>
1 parent 631d66f commit 52ea730

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/torchmetrics/utilities/distributed.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing_extensions import Literal
2020

2121

22-
def reduce(x: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor:
22+
def reduce(x: Tensor, reduction: Optional[Literal["elementwise_mean", "sum", "none"]]) -> Tensor:
2323
"""Reduces a given tensor by a given reduction method.
2424
2525
Args:
@@ -46,7 +46,7 @@ def class_reduce(
4646
num: Tensor,
4747
denom: Tensor,
4848
weights: Tensor,
49-
class_reduction: Literal["micro", "macro", "weighted", "none", None] = "none",
49+
class_reduction: Optional[Literal["micro", "macro", "weighted", "none"]] = "none",
5050
) -> Tensor:
5151
"""Reduce classification metrics of the form ``num / denom * weights``.
5252
@@ -147,7 +147,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
147147
torch.distributed.all_gather(gathered_result, result_padded, group)
148148
for idx, item_size in enumerate(local_sizes):
149149
slice_param = [slice(dim_size) for dim_size in item_size]
150-
gathered_result[idx] = gathered_result[idx][slice_param]
150+
gathered_result[idx] = gathered_result[idx][tuple(slice_param)]
151151
# to propagate autograd graph from local rank
152152
gathered_result[torch.distributed.get_rank(group)] = result
153153
return gathered_result

0 commit comments

Comments (0)