1919from typing_extensions import Literal
2020
2121
22- def reduce (x : Tensor , reduction : Literal ["elementwise_mean" , "sum" , "none" , None ]) -> Tensor :
22+ def reduce (x : Tensor , reduction : Optional [ Literal ["elementwise_mean" , "sum" , "none" ] ]) -> Tensor :
2323 """Reduces a given tensor by a given reduction method.
2424
2525 Args:
@@ -46,7 +46,7 @@ def class_reduce(
4646 num : Tensor ,
4747 denom : Tensor ,
4848 weights : Tensor ,
49- class_reduction : Literal ["micro" , "macro" , "weighted" , "none" , None ] = "none" ,
49+ class_reduction : Optional [ Literal ["micro" , "macro" , "weighted" , "none" ] ] = "none" ,
5050) -> Tensor :
5151 """Reduce classification metrics of the form ``num / denom * weights``.
5252
@@ -147,7 +147,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
147147 torch .distributed .all_gather (gathered_result , result_padded , group )
148148 for idx , item_size in enumerate (local_sizes ):
149149 slice_param = [slice (dim_size ) for dim_size in item_size ]
150- gathered_result [idx ] = gathered_result [idx ][slice_param ]
150+ gathered_result [idx ] = gathered_result [idx ][tuple ( slice_param ) ]
151151 # to propagate autograd graph from local rank
152152 gathered_result [torch .distributed .get_rank (group )] = result
153153 return gathered_result
0 commit comments