7 changes: 6 additions & 1 deletion tests/models/language/generation/test_hybrid.py
@@ -348,9 +348,14 @@ def test_fp32_cache_state(


# Helper functions for the APC tests
def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1):
def _get_vllm_runner_params(
model: str,
max_model_len: int,
tensor_parallel_size: int = 1,
):
return {
"model_name": model,
"enable_chunked_prefill": True,
"enable_prefix_caching": False,
"max_model_len": max_model_len,
"tensor_parallel_size": tensor_parallel_size,
2 changes: 2 additions & 0 deletions tests/v1/core/test_scheduler.py
@@ -2256,6 +2256,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
scheduler_config = SchedulerConfig(
enable_chunked_prefill=enable_chunked_prefill,
is_encoder_decoder=is_encoder_decoder,
# Must be <= max_num_batched_tokens if chunked prefill is disabled
max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)

# `is_encoder_decoder` should only be used during construction
1 change: 1 addition & 0 deletions tests/v1/sample/test_logprobs.py
@@ -47,6 +47,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
max_num_batched_tokens=16,
max_num_seqs=16,
max_model_len=128,
enable_chunked_prefill=True,
enforce_eager=True,
# TODO: enable this once we support it for
# prompt logprobs.
102 changes: 33 additions & 69 deletions vllm/config/scheduler.py
@@ -4,19 +4,14 @@
import hashlib
from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast

from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import (
DEFAULT_MAX_NUM_BATCHED_TOKENS,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
from vllm.utils.import_utils import resolve_obj_by_qualname

if TYPE_CHECKING:
@@ -33,25 +28,32 @@
class SchedulerConfig:
"""Scheduler configuration."""

DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""

max_num_batched_tokens: int = Field(default=None, ge=1)
max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
"""Maximum number of tokens to be processed in a single iteration.

This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""

max_num_seqs: int = Field(default=None, ge=1)
max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
"""Maximum number of sequences to be processed in a single iteration.

This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""

max_model_len: int = Field(default=None, ge=1)
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
max_model_len: int = Field(default=8192, ge=1)
"""Maximum length of a sequence (including prompt and generated text).

The default value here is mainly for convenience when testing.
In real usage, this should duplicate `ModelConfig.max_model_len` via
`EngineArgs`."""
Comment on lines +51 to +56

Member

Could we remove this entirely?

Member Author

It is used in some other places like vllm.v1.core.sched.Scheduler. We can try to refactor this in another PR.

Member

Yeah this would need a small refactor. A follow up PR sounds good.
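
For context on the duplication discussed above, here is a rough, hypothetical sketch (the helper name and structure are illustrative, not the actual `EngineArgs` code) of why the field still has to be populated: the scheduler-side config keeps its own copy of `max_model_len`, which consumers such as `vllm.v1.core.sched.Scheduler` read directly.

```python
# Hypothetical sketch only: build_scheduler_config is not a real vllm helper.
from vllm.config import ModelConfig, SchedulerConfig  # assumed import path


def build_scheduler_config(model_config: ModelConfig) -> SchedulerConfig:
    # The value from ModelConfig is duplicated manually, mirroring what
    # EngineArgs.create_engine_config is expected to do in real usage.
    return SchedulerConfig(
        max_model_len=model_config.max_model_len,
        enable_chunked_prefill=True,
    )
```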


max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be
@@ -76,9 +78,13 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""

enable_chunked_prefill: bool = Field(default=None)
enable_chunked_prefill: bool = True
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
on the remaining `max_num_batched_tokens`.

The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""

is_multimodal_model: bool = False
"""True if the model is multimodal."""
@@ -111,9 +117,6 @@ class SchedulerConfig:
- "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties)."""

chunked_prefill_enabled: bool = Field(init=False)
"""True if chunked prefill is enabled."""

disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1
@@ -188,15 +191,7 @@ def compute_hash(self) -> str:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

@field_validator(
"max_num_batched_tokens",
"max_num_seqs",
"max_model_len",
"enable_chunked_prefill",
"scheduler_cls",
"async_scheduling",
mode="wrap",
)
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
@@ -205,54 +200,16 @@ def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
return handler(value)

def __post_init__(self, is_encoder_decoder: bool) -> None:
if self.max_model_len is None:
self.max_model_len = 8192

if self.max_num_seqs is None:
self.max_num_seqs = 128

if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True
self.chunked_prefill_enabled = False
self.enable_chunked_prefill = False
self.long_prefill_token_threshold = 0
logger.info(
"Encoder-decoder models do not support chunked prefill nor"
" prefix caching; disabling both."
)

if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
else:
# If max_model_len is too short, use
# DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
# for higher throughput.
self.max_num_batched_tokens = max(
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
)

if self.runner_type == "pooling":
# Choose specific value for higher throughput
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)

# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
)

self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens

@@ -262,7 +219,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
self.max_num_batched_tokens,
)

self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -276,6 +232,14 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
self.long_prefill_token_threshold,
)

@property
def chunked_prefill_enabled(self) -> bool:
return self.enable_chunked_prefill

@chunked_prefill_enabled.setter
def chunked_prefill_enabled(self, value: bool):
self.enable_chunked_prefill = value

Comment on lines +235 to +242

Member

Can we just remove this? It used to be init=False so it's not part of the normal API of SchedulerConfig

Member Author

Same as above
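
A minimal sketch of the compatibility shim under discussion, assuming `SchedulerConfig` is importable from `vllm.config` and constructible with the new class-level defaults; the property simply mirrors `enable_chunked_prefill` in both directions:

```python
from vllm.config import SchedulerConfig  # assumed import path

# Keep max_model_len <= max_num_batched_tokens so the config stays valid even
# once chunked prefill is toggled off below.
cfg = SchedulerConfig(max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS)

assert cfg.chunked_prefill_enabled is True  # reads enable_chunked_prefill (default True)

cfg.chunked_prefill_enabled = False  # setter forwards to enable_chunked_prefill
assert cfg.enable_chunked_prefill is False
```

This keeps older call sites that read or write `chunked_prefill_enabled` working while `enable_chunked_prefill` becomes the single source of truth.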

@model_validator(mode="after")
def _verify_args(self) -> Self:
if (
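
Closing with a small hedged example of the encoder-decoder path that `__post_init__` retains (assuming `is_encoder_decoder` is still accepted as an init-only argument, as the test change above suggests): chunked prefill is forced off, which is also why that test now caps `max_model_len`.

```python
from vllm.config import SchedulerConfig  # assumed import path

cfg = SchedulerConfig(
    is_encoder_decoder=True,
    # With chunked prefill disabled, max_model_len must not exceed
    # max_num_batched_tokens (see the test_scheduler.py change above).
    max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)

assert cfg.enable_chunked_prefill is False   # forced off for encoder-decoder models
assert cfg.chunked_prefill_enabled is False  # the property mirrors the field
assert cfg.long_prefill_token_threshold == 0
assert cfg.disable_chunked_mm_input is True
```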