@@ -212,11 +212,13 @@ def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "sum") -> Uni
212212
213213 return cast (Union [torch .Tensor , float ], self ._collective_op (tensor , self ._do_all_reduce , op ))
214214
215- def all_gather (self , tensor : Union [torch .Tensor , float , str ]) -> Union [torch .Tensor , float , List [float ], List [str ]]:
215+ def all_gather (
216+ self , tensor : Union [torch .Tensor , float , str ], group : Optional [Union [Any , List [int ]]] = None
217+ ) -> Union [torch .Tensor , float , List [float ], List [str ]]:
216218 if not isinstance (tensor , (torch .Tensor , Number , str )):
217219 raise TypeError (f"Unhandled input type { type (tensor )} " )
218220
219- return self ._collective_op (tensor , self ._do_all_gather )
221+ return self ._collective_op (tensor , self ._do_all_gather , group = group )
220222
221223 def broadcast (
222224 self , tensor : Union [torch .Tensor , float , str , None ], src : int = 0 , safe_mode : bool = False
@@ -268,7 +270,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
268270 pass
269271
270272 @abstractmethod
271- def _do_all_gather (self , tensor : torch .Tensor ) -> torch .Tensor :
273+ def _do_all_gather (self , tensor : torch .Tensor , group : Optional [ Union [ Any , List [ int ]]] = None ) -> torch .Tensor :
272274 pass
273275
274276 @abstractmethod
@@ -336,7 +338,9 @@ def spawn(*args: Any, **kwargs: Any) -> None:
336338 def all_reduce (self , tensor : Union [torch .Tensor , float ], op : str = "SUM" ) -> Union [torch .Tensor , float ]:
337339 return tensor
338340
339- def all_gather (self , tensor : Union [torch .Tensor , float , str ]) -> Union [torch .Tensor , float , List [float ], List [str ]]:
341+ def all_gather (
342+ self , tensor : Union [torch .Tensor , float , str ], group : Optional [Union [Any , List [int ]]] = None
343+ ) -> Union [torch .Tensor , float , List [float ], List [str ]]:
340344 if isinstance (tensor , torch .Tensor ):
341345 return tensor
342346 return cast (Union [List [float ], List [str ]], [tensor ])
@@ -351,7 +355,7 @@ def broadcast(
351355 def _do_all_reduce (self , tensor : torch .Tensor , op : str = "SUM" ) -> torch .Tensor :
352356 return tensor
353357
354- def _do_all_gather (self , tensor : torch .Tensor ) -> torch .Tensor :
358+ def _do_all_gather (self , tensor : torch .Tensor , group : Optional [ Union [ Any , List [ int ]]] = None ) -> torch .Tensor :
355359 return tensor
356360
357361 def _do_broadcast (self , tensor : torch .Tensor , src : int ) -> torch .Tensor :
0 commit comments