112 changes: 45 additions & 67 deletions vllm/config/scheduler.py
@@ -4,19 +4,14 @@
import hashlib
from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast

from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import (
DEFAULT_MAX_NUM_BATCHED_TOKENS,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
from vllm.utils.import_utils import resolve_obj_by_qualname

if TYPE_CHECKING:
@@ -33,25 +28,32 @@
class SchedulerConfig:
"""Scheduler configuration."""

DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""

max_num_batched_tokens: int = Field(default=None, ge=1)
max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
"""Maximum number of tokens to be processed in a single iteration.

This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""

max_num_seqs: int = Field(default=None, ge=1)
max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
"""Maximum number of sequences to be processed in a single iteration.

This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""

max_model_len: int = Field(default=None, ge=1)
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
max_model_len: int = Field(default=8192, ge=1)
"""Maximum length of a sequence (including prompt and generated text).

The default value here is mainly for convenience when testing.
In real usage, this should duplicate `ModelConfig.max_model_len` via
`EngineArgs`."""
Comment on lines +51 to +56

Member: Could we remove this entirely?

Member (author): It is used in some other places like vllm.v1.core.sched.Scheduler. We can try to refactor this in another PR.

Member: Yeah, this would need a small refactor. A follow-up PR sounds good.

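Note (not from the PR): a minimal sketch of the testing convenience these static defaults buy — constructing SchedulerConfig directly instead of going through EngineArgs.create_engine_config. It assumes SchedulerConfig is importable from vllm.config and that is_encoder_decoder is the InitVar consumed by __post_init__ below; treat both as assumptions.

```python
# Hypothetical unit-test sketch: with class-level defaults, the config can be
# built without any UsageContext or EngineArgs plumbing.
from vllm.config import SchedulerConfig  # import path assumed


def test_scheduler_config_static_defaults():
    config = SchedulerConfig(is_encoder_decoder=False)

    assert config.max_num_seqs == SchedulerConfig.DEFAULT_MAX_NUM_SEQS  # 128
    assert config.max_model_len == 8192

    # With chunked prefill disabled by default, __post_init__ raises
    # max_num_batched_tokens to at least max_model_len, so the class-level
    # default of 2048 ends up bumped to 8192 here.
    assert config.max_num_batched_tokens == 8192
```
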
max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be
@@ -76,9 +78,13 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""

enable_chunked_prefill: bool = Field(default=None)
enable_chunked_prefill: bool = False
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
on the remaining `max_num_batched_tokens`.

The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""

is_multimodal_model: bool = False
"""True if the model is multimodal."""
@@ -111,9 +117,6 @@
- "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties)."""

chunked_prefill_enabled: bool = Field(init=False)
"""True if chunked prefill is enabled."""

disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1
@@ -182,15 +185,7 @@ def compute_hash(self) -> str:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

@field_validator(
"max_num_batched_tokens",
"max_num_seqs",
"max_model_len",
"enable_chunked_prefill",
"scheduler_cls",
"async_scheduling",
mode="wrap",
)
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
@@ -199,54 +194,30 @@ def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
return handler(value)

def __post_init__(self, is_encoder_decoder: bool) -> None:
if self.max_model_len is None:
self.max_model_len = 8192

if self.max_num_seqs is None:
self.max_num_seqs = 128

if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True
self.chunked_prefill_enabled = False
self.enable_chunked_prefill = False
self.long_prefill_token_threshold = 0
logger.info(
"Encoder-decoder models do not support chunked prefill nor"
" prefix caching; disabling both."
)

if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
else:
# If max_model_len is too short, use
# DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
# for higher throughput.
self.max_num_batched_tokens = max(
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
)

if self.runner_type == "pooling":
# Choose specific value for higher throughput
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)

# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
if not self.enable_chunked_prefill:
# If max_model_len is too short, use the default for higher throughput.
self.max_num_batched_tokens = max(
self.max_model_len,
self.max_num_batched_tokens,
)

# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * self.max_model_len,
self.max_num_batched_tokens,
)

self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens

Comment on lines 202 to 215

Contributor (critical): This refactoring simplifies the initialization, but it seems to have removed the special default logic for max_num_batched_tokens for pooling and multimodal models.

Previously, if max_num_batched_tokens was not set by the user or by a UsageContext-specific default, there was fallback logic to increase it for pooling models (to 32768) and multimodal models (to 5120) for better throughput. This logic was triggered if max_num_batched_tokens was None when __post_init__ was called.

That logic has now been removed. The justification in the PR description suggests it was dead code, but it appears it would have been triggered whenever no UsageContext default was found. The new implementation in EngineArgs ensures max_num_batched_tokens is always set, but the specific, higher defaults for pooling/multimodal models are no longer applied anywhere.

Removing this could lead to a significant performance regression for these model types. Could you please confirm whether this change is intended? If it is a mistake, this logic should be restored, perhaps within EngineArgs.get_batch_defaults.

Member (author): The case of no UsageContext is not normal usage of vLLM.

Member (author) @DarkLight1337, Nov 13, 2025: @njhill @WoosukKwon @robertgshaw2-redhat correct me if I'm wrong about this.

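Illustration only (not part of this PR): a hedged sketch of how the pooling/multimodal floors described above could be restored near EngineArgs.get_batch_defaults, as the comment suggests. The constants 32768 and 5120 are the values quoted in the review; the helper name and its parameters are hypothetical, since get_batch_defaults in this diff only receives world_size.

```python
# Hypothetical follow-up: re-apply the throughput-oriented floors that the old
# __post_init__ used for pooling and multimodal models.
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768    # value quoted in the review comment
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120  # value quoted in the review comment


def apply_model_type_floors(
    max_num_batched_tokens: int,
    runner_type: str,
    is_multimodal_model: bool,
) -> int:
    """Raise the default token budget for model types that benefit from it."""
    if runner_type == "pooling":
        max_num_batched_tokens = max(
            max_num_batched_tokens, POOLING_MODEL_MAX_NUM_BATCHED_TOKENS
        )
    if is_multimodal_model:
        max_num_batched_tokens = max(
            max_num_batched_tokens, MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
        )
    return max_num_batched_tokens
```
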
@@ -256,7 +227,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
self.max_num_batched_tokens,
)

self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -270,6 +240,14 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
self.long_prefill_token_threshold,
)

@property
def chunked_prefill_enabled(self) -> bool:
return self.enable_chunked_prefill

@chunked_prefill_enabled.setter
def chunked_prefill_enabled(self, value: bool):
self.enable_chunked_prefill = value

Comment on lines +235 to +242

Member: Can we just remove this? It used to be init=False, so it's not part of the normal API of SchedulerConfig.

Member (author): Same as above.

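Sketch (illustrative, not from the PR) of why the property shim preserves the old attribute: reads and writes of chunked_prefill_enabled forward to the single underlying field enable_chunked_prefill, so existing call sites keep working. The constructor arguments here are assumptions.

```python
from vllm.config import SchedulerConfig  # import path assumed

config = SchedulerConfig(is_encoder_decoder=False, enable_chunked_prefill=True)

assert config.chunked_prefill_enabled is True  # property read forwards

config.chunked_prefill_enabled = False         # setter forwards
assert config.enable_chunked_prefill is False
```
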
@model_validator(mode="after")
def _verify_args(self) -> Self:
if (
154 changes: 88 additions & 66 deletions vllm/engine/arg_utils.py
@@ -1733,41 +1733,41 @@ def _check_feature_supported(self, model_config: ModelConfig):
)
_raise_unsupported_error(feature_name=name)

def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
"""Set Default Arguments for V1 Engine."""

# V1 uses chunked prefills and prefix caching by default
# for non-pooling tasks.
# For pooling tasks the default is False
@classmethod
def get_chunked_prefill_prefix_caching_defaults(
cls,
model_config: ModelConfig,
) -> tuple[bool, bool]:
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True

if self.enable_prefix_caching is None:
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
if model_config.is_hybrid:
self.enable_prefix_caching = False
else:
self.enable_prefix_caching = True
default_chunked_prefill = True

# Disable prefix caching default for hybrid models
# since the feature is still experimental.
default_prefix_caching = not model_config.is_hybrid
else:
assert model_config.pooler_config is not None

pooling_type = model_config.pooler_config.pooling_type
is_causal = getattr(model_config.hf_config, "is_causal", True)
incremental_prefill_supported = (
pooling_type is not None
and pooling_type.lower() == "last"
and bool(is_causal)
and getattr(model_config.hf_config, "is_causal", True)
)

action = "Enabling" if incremental_prefill_supported else "Disabling"
default_chunked_prefill = incremental_prefill_supported
default_prefix_caching = incremental_prefill_supported

if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = incremental_prefill_supported
logger.info("(%s) chunked prefill by default", action)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = incremental_prefill_supported
logger.info("(%s) prefix caching by default", action)
return default_chunked_prefill, default_prefix_caching

@classmethod
def get_batch_defaults(
cls,
world_size: int,
) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
from vllm.usage.usage_lib import UsageContext

default_max_num_batched_tokens: dict[UsageContext | None, int]
default_max_num_seqs: dict[UsageContext | None, int]

# When no user override, set the default values based on the usage
# context.
@@ -1788,8 +1786,6 @@ def _set_default_args(
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
# So here we do an extra device name check to prevent such regression.
from vllm.usage.usage_lib import UsageContext

if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
# For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens = {
@@ -1813,22 +1811,26 @@

# tpu specific default values.
if current_platform.is_tpu():
default_max_num_batched_tokens_tpu = {
UsageContext.LLM_CLASS: {
"V6E": 2048,
"V5E": 1024,
"V5P": 512,
},
UsageContext.OPENAI_API_SERVER: {
"V6E": 1024,
"V5E": 512,
"V5P": 256,
},
}
chip_name = current_platform.get_device_name()

if chip_name == "V6E":
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 2048,
UsageContext.OPENAI_API_SERVER: 1024,
}
elif chip_name == "V5E":
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 1024,
UsageContext.OPENAI_API_SERVER: 512,
}
elif chip_name == "V5P":
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 512,
UsageContext.OPENAI_API_SERVER: 256,
}

# cpu specific default values.
if current_platform.is_cpu():
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 4096 * world_size,
UsageContext.OPENAI_API_SERVER: 2048 * world_size,
@@ -1838,39 +1840,59 @@
UsageContext.OPENAI_API_SERVER: 128 * world_size,
}

return default_max_num_batched_tokens, default_max_num_seqs

def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
"""Set Default Arguments for V1 Engine."""
(
default_chunked_prefill,
default_prefix_caching,
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)

if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill

if model_config.runner_type == "pooling":
logger.info(
"%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling",
)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = default_prefix_caching

if model_config.runner_type == "pooling":
logger.info(
"%s chunked prefill by default",
"Enabling" if default_prefix_caching else "Disabling",
)

world_size = self.pipeline_parallel_size * self.tensor_parallel_size
(
default_max_num_batched_tokens,
default_max_num_seqs,
) = self.get_batch_defaults(world_size)

use_context_value = usage_context.value if usage_context else None
if (
self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens
):
if current_platform.is_tpu():
chip_name = current_platform.get_device_name()
if chip_name in default_max_num_batched_tokens_tpu[usage_context]:
self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[
usage_context
][chip_name]
else:
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context
]
else:
if not self.enable_chunked_prefill:
self.max_num_batched_tokens = model_config.max_model_len
else:
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context
]
if self.max_num_batched_tokens is None:
self.max_num_batched_tokens = default_max_num_batched_tokens.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)

logger.debug(
"Setting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens,
use_context_value,
)

if self.max_num_seqs is None and usage_context in default_max_num_seqs:
self.max_num_seqs = min(
default_max_num_seqs[usage_context],
self.max_num_batched_tokens or sys.maxsize,
if self.max_num_seqs is None:
self.max_num_seqs = default_max_num_seqs.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
)
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

logger.debug(
"Setting max_num_seqs to %d for %s usage context.",