Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 92fa8db

Browse files
WoosukKwon authored and alexm-redhat committed
Disable custom all reduce by default (vllm-project#2808)
1 parent 23a9ded commit 92fa8db

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

vllm/config.py

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -423,16 +423,26 @@ def _verify_args(self) -> None:
423423
if self.pipeline_parallel_size > 1:
424424
raise NotImplementedError(
425425
"Pipeline parallelism is not supported yet.")
426-
if is_hip():
426+
if not self.disable_custom_all_reduce and self.world_size > 1:
427+
if is_hip():
428+
self.disable_custom_all_reduce = True
429+
logger.info(
430+
"Disabled the custom all-reduce kernel because it is not "
431+
"supported on AMD GPUs.")
432+
elif self.pipeline_parallel_size > 1:
433+
self.disable_custom_all_reduce = True
434+
logger.info(
435+
"Disabled the custom all-reduce kernel because it is not "
436+
"supported with pipeline parallelism.")
437+
438+
# FIXME(woosuk): Fix the stability issues and re-enable the custom
439+
# all-reduce kernel.
440+
if not self.disable_custom_all_reduce and self.world_size > 1:
427441
self.disable_custom_all_reduce = True
428442
logger.info(
429-
"Disabled the custom all-reduce kernel because it is not "
430-
"supported on AMD GPUs.")
431-
elif self.pipeline_parallel_size > 1:
432-
self.disable_custom_all_reduce = True
433-
logger.info(
434-
"Disabled the custom all-reduce kernel because it is not "
435-
"supported with pipeline parallelism.")
443+
"Custom all-reduce kernels are temporarily disabled due to "
444+
"stability issues. We will re-enable them once the issues are "
445+
"resolved.")
436446

437447

438448
class SchedulerConfig:

0 commit comments

Comments (0)