diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 681b380e6a15..37830093cd3c 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -348,9 +348,14 @@ def test_fp32_cache_state( # Helper functions for the APC tests -def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1): +def _get_vllm_runner_params( + model: str, + max_model_len: int, + tensor_parallel_size: int = 1, +): return { "model_name": model, + "enable_chunked_prefill": True, "enable_prefix_caching": False, "max_model_len": max_model_len, "tensor_parallel_size": tensor_parallel_size, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index d31338220fca..287e735b5491 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -2256,6 +2256,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder( scheduler_config = SchedulerConfig( enable_chunked_prefill=enable_chunked_prefill, is_encoder_decoder=is_encoder_decoder, + # Must <= max_num_batched_tokens if chunked prefill is disabled + max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) # `is_encoder_decoder` should only be used during construction diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 354fff22dc2a..42584938bc06 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -47,6 +47,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]: max_num_batched_tokens=16, max_num_seqs=16, max_model_len=128, + enable_chunked_prefill=True, enforce_eager=True, # TODO: enable this once we support it for # prompt logprobs. diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 71a06e167fd9..5117344a6844 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -4,7 +4,7 @@ import hashlib from collections.abc import Callable from dataclasses import InitVar -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast from pydantic import Field, field_validator, model_validator from pydantic.dataclasses import dataclass @@ -12,11 +12,6 @@ from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils import ( - DEFAULT_MAX_NUM_BATCHED_TOKENS, - MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, -) from vllm.utils.import_utils import resolve_obj_by_qualname if TYPE_CHECKING: @@ -33,25 +28,32 @@ class SchedulerConfig: """Scheduler configuration.""" + DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048 + DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128 + runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: int = Field(default=None, ge=1) + max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1) """Maximum number of tokens to be processed in a single iteration. - This config has no static default. If left unspecified by the user, it will - be set in `EngineArgs.create_engine_config` based on the usage context.""" + The default value here is mainly for convenience when testing. + In real usage, this should be set in `EngineArgs.create_engine_config`. + """ - max_num_seqs: int = Field(default=None, ge=1) + max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1) """Maximum number of sequences to be processed in a single iteration. - This config has no static default. 
If left unspecified by the user, it will - be set in `EngineArgs.create_engine_config` based on the usage context.""" + The default value here is mainly for convenience when testing. + In real usage, this should be set in `EngineArgs.create_engine_config`. + """ - max_model_len: int = Field(default=None, ge=1) - """Maximum length of a sequence (including prompt and generated text). This - is primarily set in `ModelConfig` and that value should be manually - duplicated here.""" + max_model_len: int = Field(default=8192, ge=1) + """Maximum length of a sequence (including prompt and generated text). + + The default value here is mainly for convenience when testing. + In real usage, this should duplicate `ModelConfig.max_model_len` via + `EngineArgs`.""" max_num_partial_prefills: int = Field(default=1, ge=1) """For chunked prefill, the maximum number of sequences that can be @@ -76,9 +78,13 @@ class SchedulerConfig: NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.""" - enable_chunked_prefill: bool = Field(default=None) + enable_chunked_prefill: bool = True """If True, prefill requests can be chunked based - on the remaining max_num_batched_tokens.""" + on the remaining `max_num_batched_tokens`. + + The default value here is mainly for convenience when testing. + In real usage, this should be set in `EngineArgs.create_engine_config`. + """ is_multimodal_model: bool = False """True if the model is multimodal.""" @@ -111,9 +117,6 @@ class SchedulerConfig: - "priority" means requests are handled based on given priority (lower value means earlier handling) and time of arrival deciding any ties).""" - chunked_prefill_enabled: bool = Field(init=False) - """True if chunked prefill is enabled.""" - disable_chunked_mm_input: bool = False """If set to true and chunked prefill is enabled, we do not want to partially schedule a multimodal item. Only used in V1 @@ -188,15 +191,7 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - @field_validator( - "max_num_batched_tokens", - "max_num_seqs", - "max_model_len", - "enable_chunked_prefill", - "scheduler_cls", - "async_scheduling", - mode="wrap", - ) + @field_validator("scheduler_cls", "async_scheduling", mode="wrap") @classmethod def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: """Skip validation if the value is `None` when initialisation is delayed.""" @@ -205,16 +200,9 @@ def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: return handler(value) def __post_init__(self, is_encoder_decoder: bool) -> None: - if self.max_model_len is None: - self.max_model_len = 8192 - - if self.max_num_seqs is None: - self.max_num_seqs = 128 - if is_encoder_decoder: # Chunked prefill should be disabled for encoder-decoder models. self.disable_chunked_mm_input = True - self.chunked_prefill_enabled = False self.enable_chunked_prefill = False self.long_prefill_token_threshold = 0 logger.info( @@ -222,37 +210,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: " prefix caching; disabling both." ) - if self.max_num_batched_tokens is None: - if self.enable_chunked_prefill: - self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS - else: - # If max_model_len is too short, use - # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value - # for higher throughput. 
- self.max_num_batched_tokens = max( - self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS - ) - - if self.runner_type == "pooling": - # Choose specific value for higher throughput - self.max_num_batched_tokens = max( - self.max_num_batched_tokens, - POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - if self.is_multimodal_model: - # The value needs to be at least the number of multimodal tokens - self.max_num_batched_tokens = max( - self.max_num_batched_tokens, - MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - - # When using default settings, - # Ensure max_num_batched_tokens does not exceed model limit. - # Some models (e.g., Whisper) have embeddings tied to max length. - self.max_num_batched_tokens = min( - self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens - ) - self.max_num_encoder_input_tokens = self.max_num_batched_tokens self.encoder_cache_size = self.max_num_batched_tokens @@ -262,7 +219,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: self.max_num_batched_tokens, ) - self.chunked_prefill_enabled = self.enable_chunked_prefill if self.max_num_partial_prefills > 1: if self.long_prefill_token_threshold == 0: self.long_prefill_token_threshold = int(self.max_model_len * 0.04) @@ -276,6 +232,14 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: self.long_prefill_token_threshold, ) + @property + def chunked_prefill_enabled(self) -> bool: + return self.enable_chunked_prefill + + @chunked_prefill_enabled.setter + def chunked_prefill_enabled(self, value: bool): + self.enable_chunked_prefill = value + @model_validator(mode="after") def _verify_args(self) -> Self: if ( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b025004ea022..cacebc530b6e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -428,11 +428,11 @@ class EngineArgs: cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes - max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens + max_num_batched_tokens: int | None = None max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold - max_num_seqs: int | None = SchedulerConfig.max_num_seqs + max_num_seqs: int | None = None max_logprobs: int = ModelConfig.max_logprobs logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode disable_log_stats: bool = False @@ -485,7 +485,7 @@ class EngineArgs: model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns") - enable_chunked_prefill: bool | None = SchedulerConfig.enable_chunked_prefill + enable_chunked_prefill: bool | None = None disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input disable_hybrid_kv_cache_manager: bool = ( @@ -1738,41 +1738,41 @@ def _check_feature_supported(self, model_config: ModelConfig): ) _raise_unsupported_error(feature_name=name) - def _set_default_args( - self, usage_context: UsageContext, model_config: ModelConfig - ) -> None: - """Set Default Arguments for V1 Engine.""" - - # V1 uses chunked prefills and prefix caching by default - # for non-pooling tasks. 
- # For pooling tasks the default is False + @classmethod + def get_chunked_prefill_prefix_caching_defaults( + cls, + model_config: ModelConfig, + ) -> tuple[bool, bool]: if model_config.runner_type != "pooling": - self.enable_chunked_prefill = True - - if self.enable_prefix_caching is None: - # Disable prefix caching default for hybrid models - # since the feature is still experimental. - if model_config.is_hybrid: - self.enable_prefix_caching = False - else: - self.enable_prefix_caching = True + default_chunked_prefill = True + + # Disable prefix caching default for hybrid models + # since the feature is still experimental. + default_prefix_caching = not model_config.is_hybrid else: + assert model_config.pooler_config is not None + pooling_type = model_config.pooler_config.pooling_type - is_causal = getattr(model_config.hf_config, "is_causal", True) incremental_prefill_supported = ( pooling_type is not None and pooling_type.lower() == "last" - and bool(is_causal) + and getattr(model_config.hf_config, "is_causal", True) ) - action = "Enabling" if incremental_prefill_supported else "Disabling" + default_chunked_prefill = incremental_prefill_supported + default_prefix_caching = incremental_prefill_supported + + return default_chunked_prefill, default_prefix_caching + + @classmethod + def get_batch_defaults( + cls, + world_size: int, + ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]: + from vllm.usage.usage_lib import UsageContext - if self.enable_chunked_prefill is None: - self.enable_chunked_prefill = incremental_prefill_supported - logger.info("(%s) chunked prefill by default", action) - if self.enable_prefix_caching is None: - self.enable_prefix_caching = incremental_prefill_supported - logger.info("(%s) prefix caching by default", action) + default_max_num_batched_tokens: dict[UsageContext | None, int] + default_max_num_seqs: dict[UsageContext | None, int] # When no user override, set the default values based on the usage # context. @@ -1793,8 +1793,6 @@ def _set_default_args( # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces # throughput, see PR #17885 for more details. # So here we do an extra device name check to prevent such regression. - from vllm.usage.usage_lib import UsageContext - if device_memory >= 70 * GiB_bytes and "a100" not in device_name: # For GPUs like H100 and MI300x, use larger default values. default_max_num_batched_tokens = { @@ -1818,22 +1816,26 @@ def _set_default_args( # tpu specific default values. if current_platform.is_tpu(): - default_max_num_batched_tokens_tpu = { - UsageContext.LLM_CLASS: { - "V6E": 2048, - "V5E": 1024, - "V5P": 512, - }, - UsageContext.OPENAI_API_SERVER: { - "V6E": 1024, - "V5E": 512, - "V5P": 256, - }, - } + chip_name = current_platform.get_device_name() + + if chip_name == "V6E": + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 2048, + UsageContext.OPENAI_API_SERVER: 1024, + } + elif chip_name == "V5E": + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 1024, + UsageContext.OPENAI_API_SERVER: 512, + } + elif chip_name == "V5P": + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 512, + UsageContext.OPENAI_API_SERVER: 256, + } # cpu specific default values. 
if current_platform.is_cpu(): - world_size = self.pipeline_parallel_size * self.tensor_parallel_size default_max_num_batched_tokens = { UsageContext.LLM_CLASS: 4096 * world_size, UsageContext.OPENAI_API_SERVER: 2048 * world_size, @@ -1843,44 +1845,104 @@ def _set_default_args( UsageContext.OPENAI_API_SERVER: 128 * world_size, } - use_context_value = usage_context.value if usage_context else None - if ( - self.max_num_batched_tokens is None - and usage_context in default_max_num_batched_tokens + return default_max_num_batched_tokens, default_max_num_seqs + + def _set_default_args( + self, usage_context: UsageContext, model_config: ModelConfig + ) -> None: + """Set Default Arguments for V1 Engine.""" + ( + default_chunked_prefill, + default_prefix_caching, + ) = self.get_chunked_prefill_prefix_caching_defaults(model_config) + + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = default_chunked_prefill + + logger.debug( + "%s chunked prefill by default", + "Enabling" if default_chunked_prefill else "Disabling", + ) + elif ( + model_config.runner_type == "pooling" + and self.enable_chunked_prefill + and not default_chunked_prefill ): - if current_platform.is_tpu(): - chip_name = current_platform.get_device_name() - if chip_name in default_max_num_batched_tokens_tpu[usage_context]: - self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[ - usage_context - ][chip_name] - else: - self.max_num_batched_tokens = default_max_num_batched_tokens[ - usage_context - ] - else: - if not self.enable_chunked_prefill: - self.max_num_batched_tokens = model_config.max_model_len - else: - self.max_num_batched_tokens = default_max_num_batched_tokens[ - usage_context - ] + logger.warning( + "This model does not officially support chunked prefill. " + "Enabling this manually may cause the engine to crash " + "or produce incorrect outputs.", + ) + + if self.enable_prefix_caching is None: + self.enable_prefix_caching = default_prefix_caching + logger.debug( - "Setting max_num_batched_tokens to %d for %s usage context.", + "%s prefix caching by default", + "Enabling" if default_prefix_caching else "Disabling", + ) + elif ( + model_config.runner_type == "pooling" + and self.enable_prefix_caching + and not default_prefix_caching + ): + logger.warning( + "This model does not officially support prefix caching. " + "Enabling this manually may cause the engine to crash " + "or produce incorrect outputs.", + ) + + world_size = self.pipeline_parallel_size * self.tensor_parallel_size + ( + default_max_num_batched_tokens, + default_max_num_seqs, + ) = self.get_batch_defaults(world_size) + + orig_max_num_batched_tokens = self.max_num_batched_tokens + orig_max_num_seqs = self.max_num_seqs + + if self.max_num_batched_tokens is None: + self.max_num_batched_tokens = default_max_num_batched_tokens.get( + usage_context, + SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) + + if self.max_num_seqs is None: + self.max_num_seqs = default_max_num_seqs.get( + usage_context, + SchedulerConfig.DEFAULT_MAX_NUM_SEQS, + ) + + if orig_max_num_batched_tokens is None: + if not self.enable_chunked_prefill: + # If max_model_len is too short, use the default for higher throughput. + self.max_num_batched_tokens = max( + model_config.max_model_len, + self.max_num_batched_tokens, + ) + + # When using default settings, + # Ensure max_num_batched_tokens does not exceed model limit. + # Some models (e.g., Whisper) have embeddings tied to max length. 
+ self.max_num_batched_tokens = min( + self.max_num_seqs * model_config.max_model_len, self.max_num_batched_tokens, - use_context_value, ) - if self.max_num_seqs is None and usage_context in default_max_num_seqs: - self.max_num_seqs = min( - default_max_num_seqs[usage_context], - self.max_num_batched_tokens or sys.maxsize, + logger.debug( + "Defaulting max_num_batched_tokens to %d for %s usage context.", + self.max_num_batched_tokens, + usage_context.value if usage_context else None, ) + if orig_max_num_seqs is None: + assert self.max_num_batched_tokens is not None # For type checking + self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens) + logger.debug( - "Setting max_num_seqs to %d for %s usage context.", + "Defaulting max_num_seqs to %d for %s usage context.", self.max_num_seqs, - use_context_value, + usage_context.value if usage_context else None, ) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index cf954768689f..fdfa1c19789c 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -15,7 +15,6 @@ from vllm import envs from vllm.logger import init_logger -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import CpuArchEnum, Platform, PlatformEnum @@ -339,10 +338,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "prefill and prefix caching to be disabled." ) vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS, + vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) @classmethod diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index b997bb9e6999..4ab037fdb77e 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -10,7 +10,6 @@ from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import Platform, PlatformEnum @@ -186,10 +185,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "prefill and prefix caching to be disabled." ) vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS, + vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) @classmethod diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 5552e4ca4b2f..ad4beb28bdae 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -9,7 +9,6 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import DeviceCapability, Platform, PlatformEnum @@ -185,10 +184,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "prefill and prefix caching to be disabled." 
) vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS, + vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) @classmethod diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 040c0416c5ea..3ef44e770320 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3,7 +3,7 @@ import uuid import warnings -from typing import Any, TypeVar +from typing import Any import torch @@ -39,12 +39,6 @@ def __dir__() -> list[str]: logger = init_logger(__name__) -# This value is chosen to have a balance between ITL and TTFT. Note it is -# not optimized for throughput. -DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 -POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 -MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 - # Constants related to forcing the attention backend selection # String name of register which may be set in order to @@ -60,9 +54,6 @@ def __dir__() -> list[str]: STR_INVALID_VAL: str = "INVALID" -T = TypeVar("T") - - def random_uuid() -> str: return str(uuid.uuid4().hex)
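
The stated motivation for the new class-level defaults is convenience when testing: `SchedulerConfig` can now be constructed directly without `EngineArgs.create_engine_config` filling in `max_num_batched_tokens`, `max_num_seqs`, or `enable_chunked_prefill`. A short illustrative usage, assuming a build with this patch applied; the constructor keywords mirror the `test_chunked_prefill_disabled_for_encoder_decoder` call above:

```python
from vllm.config.scheduler import SchedulerConfig

# Bare construction now picks up the class-level defaults instead of
# requiring EngineArgs.create_engine_config to fill them in.
cfg = SchedulerConfig(is_encoder_decoder=False)
assert cfg.max_num_batched_tokens == SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS  # 2048
assert cfg.max_num_seqs == SchedulerConfig.DEFAULT_MAX_NUM_SEQS  # 128
assert cfg.max_model_len == 8192
assert cfg.chunked_prefill_enabled  # property alias of enable_chunked_prefill
```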
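
Because `chunked_prefill_enabled` turns from a stored field into a property over `enable_chunked_prefill`, any remaining readers or writers of the old name (the platform overrides above used to assign it directly) keep seeing a single source of truth. A minimal standalone sketch of the pattern, on a plain dataclass rather than the real pydantic config:

```python
# Illustrative only; the real SchedulerConfig is a pydantic dataclass with
# many more fields and validators.
from dataclasses import dataclass


@dataclass
class SchedulerConfigSketch:
    enable_chunked_prefill: bool = True

    @property
    def chunked_prefill_enabled(self) -> bool:
        # Reads through the old name reflect the canonical field.
        return self.enable_chunked_prefill

    @chunked_prefill_enabled.setter
    def chunked_prefill_enabled(self, value: bool) -> None:
        # Writes through the old name update the canonical field.
        self.enable_chunked_prefill = value


cfg = SchedulerConfigSketch()
cfg.chunked_prefill_enabled = False      # old-style write
assert cfg.enable_chunked_prefill is False
cfg.enable_chunked_prefill = True        # new-style write
assert cfg.chunked_prefill_enabled is True
```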
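
For pooling models, `get_chunked_prefill_prefix_caching_defaults` only turns chunked prefill and prefix caching on when incremental prefill is safe, i.e. last-token pooling on a causal model. A hypothetical, simplified mirror of that decision (not the actual classmethod, which also handles the non-pooling branch and hybrid models):

```python
# Hypothetical mirror of the pooling branch of
# EngineArgs.get_chunked_prefill_prefix_caching_defaults.
def pooling_defaults(pooling_type: str | None, is_causal: bool) -> tuple[bool, bool]:
    # Incremental prefill is only assumed safe for causal last-token pooling,
    # so both features default to the same flag.
    supported = (
        pooling_type is not None
        and pooling_type.lower() == "last"
        and is_causal
    )
    return supported, supported  # (enable_chunked_prefill, enable_prefix_caching)


assert pooling_defaults("LAST", True) == (True, True)
assert pooling_defaults("MEAN", True) == (False, False)
assert pooling_defaults("last", False) == (False, False)
```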
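
The `max_num_batched_tokens` defaulting that used to live in `SchedulerConfig.__post_init__` now runs in `_set_default_args`, and only when the user left the value unset: start from the usage-context table (falling back to the class default), raise it to `max_model_len` when chunked prefill is off so a full prompt still fits in one step, then clamp it to `max_num_seqs * max_model_len`. A self-contained sketch of that flow with hypothetical helper and argument names:

```python
# Hypothetical, simplified mirror of the new defaulting flow in
# EngineArgs._set_default_args; the constant matches the class default above.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048  # SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS


def resolve_max_num_batched_tokens(
    user_value: int | None,
    context_default: int | None,
    enable_chunked_prefill: bool,
    max_model_len: int,
    max_num_seqs: int,
) -> int:
    if user_value is not None:
        # Explicit user settings are never adjusted.
        return user_value

    # Start from the usage-context default, falling back to the class default.
    value = context_default if context_default is not None else DEFAULT_MAX_NUM_BATCHED_TOKENS

    if not enable_chunked_prefill:
        # Without chunked prefill, a whole prompt must fit in one batch,
        # so max_model_len must not exceed max_num_batched_tokens.
        value = max(value, max_model_len)

    # Keep the defaulted value within the model limit; some models
    # (e.g. Whisper) tie embeddings to the maximum length.
    return min(value, max_num_seqs * max_model_len)


# E.g. chunked prefill off with a 32k-token model raises the default to 32768.
assert resolve_max_num_batched_tokens(None, 8192, False, 32768, 128) == 32768
```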