Skip to content

Commit 08f0a99

Browse files
leo-pony authored and Bofeng BF1 Xue committed
[Ascend]: Fixed the issue where OOT Platform vllm-ascend could not enable SP in Eager mode (vllm-project#28935)
Signed-off-by: leo-pony <[email protected]> Signed-off-by: Bofeng BF1 Xue <[email protected]>
1 parent 97bf422 commit 08f0a99

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

vllm/config/compilation.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -855,6 +855,13 @@ def post_init_cudagraph_sizes(self) -> None:
855855
self.compute_bs_to_padded_graph_size()
856856

857857
def set_splitting_ops_for_v1(self):
858+
# To be compatible with an OOT hardware plugin platform (for example vllm-ascend)
859+
# which currently only supports sequence parallelism in eager mode.
860+
if self.mode != CompilationMode.VLLM_COMPILE:
861+
if self.splitting_ops is None:
862+
self.splitting_ops = []
863+
return
864+
858865
# NOTE: this function needs to be called only when mode is
859866
# CompilationMode.VLLM_COMPILE
860867
assert self.mode == CompilationMode.VLLM_COMPILE, (

vllm/config/vllm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,15 +797,21 @@ def has_blocked_weights():
797797
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
798798

799799
# Do this after all the updates to compilation_config.mode
800-
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
801-
self.compilation_config.set_splitting_ops_for_v1()
800+
self.compilation_config.set_splitting_ops_for_v1()
802801

803802
if self.compilation_config.pass_config.enable_sequence_parallelism:
804803
# With pipeline parallelism or dynamo partitioning,
805804
# native rms norm tracing errors due to incorrect residual shape.
806805
# Use custom rms norm to unblock. In the future,
807806
# the pass will operate on higher-level IR to avoid the issue.
808807
# TODO: https://github.com/vllm-project/vllm/issues/27894
808+
if self.compilation_config.mode != CompilationMode.VLLM_COMPILE:
809+
logger.warning(
810+
"Sequence parallelism is enabled, but running in wrong "
811+
"vllm compile mode: %s.",
812+
self.compilation_config.mode,
813+
)
814+
809815
is_fullgraph = (
810816
self.compilation_config.use_inductor_graph_partition
811817
or len(self.compilation_config.splitting_ops) == 0

0 commit comments

Comments
 (0)