@@ -105,7 +105,7 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
     group = _groups[group_name]()
     if group is None:
         raise ValueError(f"Group {group_name} is destroyed.")
-    group._all_reduce(tensor)
+    group._all_reduce_in_place(tensor)

 @inplace_all_reduce.register_fake
 def _(tensor: torch.Tensor, group_name: str) -> None:
@@ -118,7 +118,7 @@ def outplace_all_reduce(tensor: torch.Tensor,
     group = _groups[group_name]()
     if group is None:
         raise ValueError(f"Group {group_name} is destroyed.")
-    return group._all_reduce(tensor)
+    return group._all_reduce_out_place(tensor)

 @outplace_all_reduce.register_fake
 def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
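For context, the `register_fake` decorators in the hunks above pair each real op with a shape-only implementation so Dynamo/torch.compile can trace the op without executing a collective. A minimal sketch of that registration pattern, assuming PyTorch >= 2.4 and a made-up `demo` namespace (illustrative only, not vLLM code):

```python
import torch
from torch.library import custom_op

# Real implementation: runs eagerly on actual tensors.
@custom_op("demo::scale", mutates_args=())
def scale(tensor: torch.Tensor, factor: float) -> torch.Tensor:
    return tensor * factor

# Fake (meta) implementation: only describes output shape/dtype for tracing.
@scale.register_fake
def _(tensor: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(tensor)

out = torch.ops.demo.scale(torch.ones(3), 2.0)  # dispatches to the real impl
```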
@@ -338,40 +338,33 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             return input_

         if not supports_custom_op():
-            return self._all_reduce(input_)
+            self._all_reduce_in_place(input_)
+            return input_

         if self.tpu_communicator is not None and \
                 not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
-            return self._all_reduce(input_)
+            return self.tpu_communicator.all_reduce(input_)

-        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+        if self.ca_comm is not None and \
+                not self.ca_comm.disabled and \
+                self.ca_comm.should_custom_ar(input_):
             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
         else:
             torch.ops.vllm.inplace_all_reduce(input_,
                                               group_name=self.unique_name)
             return input_

-    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        """
-        The actual all-reduce implementation.
-
-        NOTE: This operation will be applied in-place or out-of-place.
-        Always assume this function modifies its input, but use the return
-        value as the output.
-        """
+    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         ca_comm = self.ca_comm
+        assert ca_comm is not None
+        assert not ca_comm.disabled
+        out = ca_comm.custom_all_reduce(input_)
+        assert out is not None
+        return out

-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_reduce(input_)
-
-        if ca_comm is not None:
-            out = ca_comm.custom_all_reduce(input_)
-            if out is not None:
-                return out
+    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
@@ -380,7 +373,6 @@ def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
-        return input_

     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
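The split above gives the two helpers different semantics: `_all_reduce_in_place` mutates its argument and returns nothing, while `_all_reduce_out_place` returns a fresh tensor from the custom allreduce. A minimal standalone sketch of that contrast using plain `torch.distributed` with a single-process gloo group (illustrative only, not vLLM code):

```python
import torch
import torch.distributed as dist

def all_reduce_in_place(t: torch.Tensor) -> None:
    dist.all_reduce(t)      # result is written into t; nothing is returned

def all_reduce_out_place(t: torch.Tensor) -> torch.Tensor:
    out = t.clone()         # leave the caller's tensor untouched
    dist.all_reduce(out)
    return out

if __name__ == "__main__":
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    x = torch.ones(4)
    y = all_reduce_out_place(x)   # x unchanged, y holds the reduced values
    all_reduce_in_place(x)        # x now holds the reduced values
    dist.destroy_process_group()
```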