
Commit 48f4f4f

DarkLight1337 authored and bwasti committed
[Config] Clean up SchedulerConfig initialization (vllm-project#28665)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Bram Wasti <[email protected]>
1 parent fd1b584 commit 48f4f4f

File tree

9 files changed: +181 -162 lines changed


tests/models/language/generation/test_hybrid.py

Lines changed: 6 additions & 1 deletion
@@ -348,9 +348,14 @@ def test_fp32_cache_state(
 
 
 # Helper functions for the APC tests
-def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1):
+def _get_vllm_runner_params(
+    model: str,
+    max_model_len: int,
+    tensor_parallel_size: int = 1,
+):
     return {
         "model_name": model,
+        "enable_chunked_prefill": True,
         "enable_prefix_caching": False,
         "max_model_len": max_model_len,
         "tensor_parallel_size": tensor_parallel_size,

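To make the new signature concrete, here is a self-contained sketch of the updated helper and a hypothetical way a test might tweak its output. Only the dictionary keys visible in the hunk above are reproduced, and the model name and the prefix-caching tweak are illustrative placeholders, not part of the commit.

def _get_vllm_runner_params(
    model: str,
    max_model_len: int,
    tensor_parallel_size: int = 1,
):
    # Only the keys shown in the diff; the real helper may return more.
    return {
        "model_name": model,
        "enable_chunked_prefill": True,
        "enable_prefix_caching": False,
        "max_model_len": max_model_len,
        "tensor_parallel_size": tensor_parallel_size,
    }


# Hypothetical usage: a test adjusts individual settings before building a runner,
# e.g. enabling prefix caching for an APC case.
params = _get_vllm_runner_params("my-org/my-model", max_model_len=1024)
params["enable_prefix_caching"] = True
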
tests/v1/core/test_scheduler.py

Lines changed: 2 additions & 0 deletions
@@ -2256,6 +2256,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
     scheduler_config = SchedulerConfig(
         enable_chunked_prefill=enable_chunked_prefill,
         is_encoder_decoder=is_encoder_decoder,
+        # Must <= max_num_batched_tokens if chunked prefill is disabled
+        max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
     )
 
     # `is_encoder_decoder` should only be used during construction
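
The added `max_model_len` argument works around a constraint implied by the comment above: when chunked prefill is disabled, a whole prompt must fit into a single batch, so `max_model_len` may not exceed `max_num_batched_tokens`. Below is a standalone sketch of that check in plain Python; it mirrors the idea, not vLLM's actual validator code.

def check_batch_budget(
    max_model_len: int,
    max_num_batched_tokens: int,
    enable_chunked_prefill: bool,
) -> None:
    # Without chunked prefill, a full prompt has to fit into one scheduling step.
    if not enable_chunked_prefill and max_num_batched_tokens < max_model_len:
        raise ValueError(
            f"max_num_batched_tokens ({max_num_batched_tokens}) must be >= "
            f"max_model_len ({max_model_len}) when chunked prefill is disabled"
        )


check_batch_budget(2048, 2048, enable_chunked_prefill=False)  # fine with the new defaults
check_batch_budget(8192, 2048, enable_chunked_prefill=True)   # fine: the prompt is chunked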

tests/v1/sample/test_logprobs.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
         max_num_batched_tokens=16,
         max_num_seqs=16,
         max_model_len=128,
+        enable_chunked_prefill=True,
         enforce_eager=True,
         # TODO: enable this once we support it for
         # prompt logprobs.
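
Making the flag explicit matters here because the fixture's token budget (16) is far below its `max_model_len` (128); with chunked prefill enabled, long prompts are split across scheduling steps instead of being rejected. A tiny illustrative calculation of the minimum number of prefill steps that budget implies (not vLLM code):

import math

max_model_len = 128
max_num_batched_tokens = 16

# Lower bound on scheduler steps needed to prefill a maximal-length prompt,
# assuming the whole batch budget goes to this single request.
min_prefill_steps = math.ceil(max_model_len / max_num_batched_tokens)
assert min_prefill_steps == 8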

vllm/config/scheduler.py

Lines changed: 33 additions & 69 deletions
@@ -4,19 +4,14 @@
 import hashlib
 from collections.abc import Callable
 from dataclasses import InitVar
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
 
 from pydantic import Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
 from typing_extensions import Self
 
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import (
-    DEFAULT_MAX_NUM_BATCHED_TOKENS,
-    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
-)
 from vllm.utils.import_utils import resolve_obj_by_qualname
 
 if TYPE_CHECKING:
@@ -33,25 +28,32 @@
 class SchedulerConfig:
     """Scheduler configuration."""
 
+    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
+    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
+
     runner_type: RunnerType = "generate"
     """The runner type to launch for the model."""
 
-    max_num_batched_tokens: int = Field(default=None, ge=1)
+    max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
     """Maximum number of tokens to be processed in a single iteration.
 
-    This config has no static default. If left unspecified by the user, it will
-    be set in `EngineArgs.create_engine_config` based on the usage context."""
+    The default value here is mainly for convenience when testing.
+    In real usage, this should be set in `EngineArgs.create_engine_config`.
+    """
 
-    max_num_seqs: int = Field(default=None, ge=1)
+    max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
     """Maximum number of sequences to be processed in a single iteration.
 
-    This config has no static default. If left unspecified by the user, it will
-    be set in `EngineArgs.create_engine_config` based on the usage context."""
+    The default value here is mainly for convenience when testing.
+    In real usage, this should be set in `EngineArgs.create_engine_config`.
+    """
 
-    max_model_len: int = Field(default=None, ge=1)
-    """Maximum length of a sequence (including prompt and generated text). This
-    is primarily set in `ModelConfig` and that value should be manually
-    duplicated here."""
+    max_model_len: int = Field(default=8192, ge=1)
+    """Maximum length of a sequence (including prompt and generated text).
+
+    The default value here is mainly for convenience when testing.
+    In real usage, this should duplicate `ModelConfig.max_model_len` via
+    `EngineArgs`."""
 
     max_num_partial_prefills: int = Field(default=1, ge=1)
     """For chunked prefill, the maximum number of sequences that can be
@@ -76,9 +78,13 @@ class SchedulerConfig:
     NOTE: This will be replaced by speculative config in the future; it is
     present to enable correctness tests until then."""
 
-    enable_chunked_prefill: bool = Field(default=None)
+    enable_chunked_prefill: bool = True
     """If True, prefill requests can be chunked based
-    on the remaining max_num_batched_tokens."""
+    on the remaining `max_num_batched_tokens`.
+
+    The default value here is mainly for convenience when testing.
+    In real usage, this should be set in `EngineArgs.create_engine_config`.
+    """
 
     is_multimodal_model: bool = False
     """True if the model is multimodal."""
@@ -111,9 +117,6 @@ class SchedulerConfig:
     - "priority" means requests are handled based on given priority (lower
     value means earlier handling) and time of arrival deciding any ties)."""
 
-    chunked_prefill_enabled: bool = Field(init=False)
-    """True if chunked prefill is enabled."""
-
     disable_chunked_mm_input: bool = False
     """If set to true and chunked prefill is enabled, we do not want to
     partially schedule a multimodal item. Only used in V1
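
The stored `chunked_prefill_enabled` field removed here is replaced by a property later in this diff, and together with the class variables and static defaults added above, a bare config is now directly constructible in tests. A simplified stand-in showing what those defaults imply; it is an ordinary dataclass rather than the real pydantic one, with field names and values copied from the hunks above and everything else omitted.

from dataclasses import dataclass
from typing import ClassVar


@dataclass
class SchedulerDefaultsSketch:
    # Values mirror the diff; the real SchedulerConfig has many more fields
    # plus pydantic validation.
    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

    max_num_batched_tokens: int = DEFAULT_MAX_NUM_BATCHED_TOKENS
    max_num_seqs: int = DEFAULT_MAX_NUM_SEQS
    max_model_len: int = 8192
    enable_chunked_prefill: bool = True


cfg = SchedulerDefaultsSketch()
assert cfg.max_num_batched_tokens == 2048
assert cfg.enable_chunked_prefill  # chunked prefill is now on by default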
@@ -188,15 +191,7 @@ def compute_hash(self) -> str:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
 
-    @field_validator(
-        "max_num_batched_tokens",
-        "max_num_seqs",
-        "max_model_len",
-        "enable_chunked_prefill",
-        "scheduler_cls",
-        "async_scheduling",
-        mode="wrap",
-    )
+    @field_validator("scheduler_cls", "async_scheduling", mode="wrap")
     @classmethod
     def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
         """Skip validation if the value is `None` when initialisation is delayed."""
@@ -205,54 +200,16 @@ def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
         return handler(value)
 
     def __post_init__(self, is_encoder_decoder: bool) -> None:
-        if self.max_model_len is None:
-            self.max_model_len = 8192
-
-        if self.max_num_seqs is None:
-            self.max_num_seqs = 128
-
         if is_encoder_decoder:
             # Chunked prefill should be disabled for encoder-decoder models.
             self.disable_chunked_mm_input = True
-            self.chunked_prefill_enabled = False
             self.enable_chunked_prefill = False
             self.long_prefill_token_threshold = 0
             logger.info(
                 "Encoder-decoder models do not support chunked prefill nor"
                 " prefix caching; disabling both."
             )
 
-        if self.max_num_batched_tokens is None:
-            if self.enable_chunked_prefill:
-                self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
-            else:
-                # If max_model_len is too short, use
-                # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
-                # for higher throughput.
-                self.max_num_batched_tokens = max(
-                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
-                )
-
-            if self.runner_type == "pooling":
-                # Choose specific value for higher throughput
-                self.max_num_batched_tokens = max(
-                    self.max_num_batched_tokens,
-                    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
-                )
-            if self.is_multimodal_model:
-                # The value needs to be at least the number of multimodal tokens
-                self.max_num_batched_tokens = max(
-                    self.max_num_batched_tokens,
-                    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-                )
-
-            # When using default settings,
-            # Ensure max_num_batched_tokens does not exceed model limit.
-            # Some models (e.g., Whisper) have embeddings tied to max length.
-            self.max_num_batched_tokens = min(
-                self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
-            )
-
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
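
Per the updated docstrings, the usage-context-dependent sizing removed here is expected to happen in `EngineArgs.create_engine_config` instead. For readers who relied on the old in-config fallback, the deleted rules can be expressed as a small standalone function; this is an illustration of the removed logic only, with the pooling and multimodal floors passed in as parameters because their concrete values are not part of this diff.

def derive_max_num_batched_tokens(
    max_model_len: int,
    max_num_seqs: int,
    enable_chunked_prefill: bool,
    *,
    default_budget: int = 2048,
    pooling_floor: int | None = None,     # e.g. a pooling-model minimum, if applicable
    multimodal_floor: int | None = None,  # e.g. a multimodal minimum, if applicable
) -> int:
    """Illustrative reproduction of the sizing fallback deleted above."""
    if enable_chunked_prefill:
        budget = default_budget
    else:
        # Without chunked prefill the whole prompt must fit in one step, but
        # keep at least the default budget for throughput on short models.
        budget = max(max_model_len, default_budget)

    # The old code raised the budget for pooling and multimodal models.
    if pooling_floor is not None:
        budget = max(budget, pooling_floor)
    if multimodal_floor is not None:
        budget = max(budget, multimodal_floor)

    # Never exceed what the batch can actually hold; some models (e.g. Whisper)
    # tie their embeddings to the maximum length.
    return min(budget, max_num_seqs * max_model_len)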

@@ -262,7 +219,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
             self.max_num_batched_tokens,
         )
 
-        self.chunked_prefill_enabled = self.enable_chunked_prefill
         if self.max_num_partial_prefills > 1:
             if self.long_prefill_token_threshold == 0:
                 self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -276,6 +232,14 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
                 self.long_prefill_token_threshold,
             )
 
+    @property
+    def chunked_prefill_enabled(self) -> bool:
+        return self.enable_chunked_prefill
+
+    @chunked_prefill_enabled.setter
+    def chunked_prefill_enabled(self, value: bool):
+        self.enable_chunked_prefill = value
+
     @model_validator(mode="after")
     def _verify_args(self) -> Self:
         if (
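
The new property keeps `chunked_prefill_enabled` working as a read/write alias for `enable_chunked_prefill`, so existing call sites keep working without the removed synchronization field. A minimal standalone sketch of the same pattern, using an ordinary class rather than the real config:

class ChunkedPrefillAliasSketch:
    """Standalone illustration of the alias property added above."""

    def __init__(self, enable_chunked_prefill: bool = True) -> None:
        self.enable_chunked_prefill = enable_chunked_prefill

    @property
    def chunked_prefill_enabled(self) -> bool:
        return self.enable_chunked_prefill

    @chunked_prefill_enabled.setter
    def chunked_prefill_enabled(self, value: bool) -> None:
        self.enable_chunked_prefill = value


cfg = ChunkedPrefillAliasSketch()
cfg.chunked_prefill_enabled = False          # old-style writes still work...
assert cfg.enable_chunked_prefill is False   # ...and land on the real field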
