@@ -213,11 +213,13 @@ def all_reduce(
213213
214214 return cast (Union [torch .Tensor , float ], self ._collective_op (tensor , self ._do_all_reduce , op , group = group ))
215215
216- def all_gather (self , tensor : Union [torch .Tensor , float , str ]) -> Union [torch .Tensor , float , List [float ], List [str ]]:
216+ def all_gather (
217+ self , tensor : Union [torch .Tensor , float , str ], group : Optional [Any ] = None
218+ ) -> Union [torch .Tensor , float , List [float ], List [str ]]:
217219 if not isinstance (tensor , (torch .Tensor , Number , str )):
218220 raise TypeError (f"Unhandled input type { type (tensor )} " )
219221
220- return self ._collective_op (tensor , self ._do_all_gather )
222+ return self ._collective_op (tensor , self ._do_all_gather , group = group )
221223
222224 def new_group (self , ranks : List [int ], ** kwargs : Any ) -> Any :
223225 if isinstance (ranks , list ) and all (isinstance (item , int ) for item in ranks ):
@@ -275,7 +277,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
275277 pass
276278
277279 @abstractmethod
278- def _do_all_gather (self , tensor : torch .Tensor ) -> torch .Tensor :
280+ def _do_all_gather (self , tensor : torch .Tensor , group : Optional [ Any ] = None ) -> torch .Tensor :
279281 pass
280282
281283 @abstractmethod
@@ -349,7 +351,9 @@ def all_reduce(
349351 ) -> Union [torch .Tensor , float ]:
350352 return tensor
351353
352- def all_gather (self , tensor : Union [torch .Tensor , float , str ]) -> Union [torch .Tensor , float , List [float ], List [str ]]:
354+ def all_gather (
355+ self , tensor : Union [torch .Tensor , float , str ], group : Optional [Any ] = None
356+ ) -> Union [torch .Tensor , float , List [float ], List [str ]]:
353357 if isinstance (tensor , torch .Tensor ):
354358 return tensor
355359 return cast (Union [List [float ], List [str ]], [tensor ])
@@ -364,7 +368,7 @@ def broadcast(
364368 def _do_all_reduce (self , tensor : torch .Tensor , op : str = "SUM" , group : Optional [Any ] = None ) -> torch .Tensor :
365369 return tensor
366370
367- def _do_all_gather (self , tensor : torch .Tensor ) -> torch .Tensor :
371+ def _do_all_gather (self , tensor : torch .Tensor , group : Optional [ Any ] = None ) -> torch .Tensor :
368372 return tensor
369373
370374 def _do_new_group (self , ranks : List [int ], ** kwargs : Any ) -> Any :
0 commit comments