Skip to content

Commit ea4ff8e

Browse files
committed
fix splitting ops
1 parent ba9e563 commit ea4ff8e

File tree

3 files changed, with 8 additions and 17 deletions.

vllm/compilation/nanoflow/manager.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ def __init__(
3838
tag_graph(
3939
self.original_graph_module,
4040
{
41-
"vllm.unified_attention": "attention",
42-
"vllm.unified_attention_with_output": "attention",
4341
"vllm.all_reduce": "all_reduce",
4442
},
4543
)

vllm/config/vllm.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -433,20 +433,12 @@ def __post_init__(self):
433433
"nano batch split. Disabling nano batch split."
434434
)
435435
self.compilation_config.enable_nano_batch_split = False
436-
else:
437-
nano_batch_splitting_ops = [
438-
"vllm.all_reduce",
439-
]
440-
if self.compilation_config.splitting_ops and set(
441-
self.compilation_config.splitting_ops
442-
) != set(nano_batch_splitting_ops):
443-
logger.info(
444-
"splitting_ops is not supported with "
445-
"nano batch split. Disabling nano batch split."
446-
)
447-
self.compilation_config.enable_nano_batch_split = False
448-
else:
449-
self.compilation_config.splitting_ops = nano_batch_splitting_ops
436+
elif self.compilation_config.splitting_ops and \
437+
"vllm.all_reduce" not in self.compilation_config.splitting_ops:
438+
logger.info(
439+
"adding vllm.all_reduce to splitting_ops for nano batch split."
440+
)
441+
self.compilation_config.splitting_ops.append("vllm.all_reduce")
450442

451443
# If the user does not explicitly set a compilation mode, then
452444
# we use the default mode. The default mode depends on other

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4195,7 +4195,8 @@ def initialize_metadata_builders(
41954195
if kv_cache_group_id < len(kernel_block_sizes)
41964196
else None,
41974197
num_metadata_builders=1
4198-
if not self.parallel_config.enable_dbo
4198+
if not self.parallel_config.enable_dbo and
4199+
not self.compilation_config.enable_nano_batch_split
41994200
else 2,
42004201
)
42014202
# Calculate reorder batch threshold (if needed)

0 commit comments

Comments
 (0)