This repository was archived by the owner on Oct 11, 2024. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +18
-8
lines changed
Expand file tree Collapse file tree 1 file changed +18
-8
lines changed Original file line number Diff line number Diff line change @@ -423,16 +423,26 @@ def _verify_args(self) -> None:
423423 if self .pipeline_parallel_size > 1 :
424424 raise NotImplementedError (
425425 "Pipeline parallelism is not supported yet." )
426- if is_hip ():
426+ if not self .disable_custom_all_reduce and self .world_size > 1 :
427+ if is_hip ():
428+ self .disable_custom_all_reduce = True
429+ logger .info (
430+ "Disabled the custom all-reduce kernel because it is not "
431+ "supported on AMD GPUs." )
432+ elif self .pipeline_parallel_size > 1 :
433+ self .disable_custom_all_reduce = True
434+ logger .info (
435+ "Disabled the custom all-reduce kernel because it is not "
436+ "supported with pipeline parallelism." )
437+
438+ # FIXME(woosuk): Fix the stability issues and re-enable the custom
439+ # all-reduce kernel.
440+ if not self .disable_custom_all_reduce and self .world_size > 1 :
427441 self .disable_custom_all_reduce = True
428442 logger .info (
429- "Disabled the custom all-reduce kernel because it is not "
430- "supported on AMD GPUs." )
431- elif self .pipeline_parallel_size > 1 :
432- self .disable_custom_all_reduce = True
433- logger .info (
434- "Disabled the custom all-reduce kernel because it is not "
435- "supported with pipeline parallelism." )
443+ "Custom all-reduce kernels are temporarily disabled due to "
444+ "stability issues. We will re-enable them once the issues are "
445+ "resolved." )
436446
437447
438448class SchedulerConfig :
You can’t perform that action at this time.
0 commit comments