diff --git a/tests/plugins_tests/test_scheduler_plugins.py b/tests/plugins_tests/test_scheduler_plugins.py
index 84688cee9660..98981a81e909 100644
--- a/tests/plugins_tests/test_scheduler_plugins.py
+++ b/tests/plugins_tests/test_scheduler_plugins.py
@@ -1,27 +1,35 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import pytest
+
 from vllm.core.scheduler import Scheduler
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
+from vllm.v1.core.scheduler import Scheduler as V1Scheduler
+from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
 
 
-class DummyScheduler(Scheduler):
+class DummyV0Scheduler(Scheduler):
 
     def schedule(self):
-        raise Exception("Exception raised by DummyScheduler")
+        raise Exception("Exception raised by DummyV0Scheduler")
+
 
+class DummyV1Scheduler(V1Scheduler):
 
-def test_scheduler_plugins():
-    import pytest
+    def schedule(self):
+        raise Exception("Exception raised by DummyV1Scheduler")
 
-    from vllm.engine.arg_utils import EngineArgs
-    from vllm.engine.llm_engine import LLMEngine
-    from vllm.sampling_params import SamplingParams
 
+def test_scheduler_plugins_v0(monkeypatch):
+    monkeypatch.setenv("VLLM_USE_V1", "0")
     with pytest.raises(Exception) as exception_info:
 
         engine_args = EngineArgs(
             model="facebook/opt-125m",
             enforce_eager=True,  # reduce test time
-            scheduler_cls=DummyScheduler,
+            scheduler_cls=DummyV0Scheduler,
         )
 
         engine = LLMEngine.from_engine_args(engine_args=engine_args)
@@ -30,4 +38,27 @@ def test_scheduler_plugins():
         engine.add_request("0", "foo", sampling_params)
         engine.step()
 
-    assert str(exception_info.value) == "Exception raised by DummyScheduler"
+    assert str(exception_info.value) == "Exception raised by DummyV0Scheduler"
+
+
+def test_scheduler_plugins_v1(monkeypatch):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    # Explicitly turn off engine multiprocessing so that the scheduler runs in
+    # this process
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
+    with pytest.raises(Exception) as exception_info:
+
+        engine_args = EngineArgs(
+            model="facebook/opt-125m",
+            enforce_eager=True,  # reduce test time
+            scheduler_cls=DummyV1Scheduler,
+        )
+
+        engine = V1LLMEngine.from_engine_args(engine_args=engine_args)
+
+        sampling_params = SamplingParams(max_tokens=1)
+        engine.add_request("0", "foo", sampling_params)
+        engine.step()
+
+    assert str(exception_info.value) == "Exception raised by DummyV1Scheduler"
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 0d285acd15f3..c042171ae6c8 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1437,6 +1437,11 @@ def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
         # V1 always uses chunked prefills.
         self.enable_chunked_prefill = True
 
+        # V1 should use the new scheduler by default.
+        # Swap it only if this arg is set to the original V0 default
+        if self.scheduler_cls == EngineArgs.scheduler_cls:
+            self.scheduler_cls = "vllm.v1.core.scheduler.Scheduler"
+
         # When no user override, set the default values based on the usage
         # context.
         # Use different default values for different hardware.
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index bdf9203b1b1d..38ddaf05e38d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -19,9 +19,11 @@
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import get_exception_traceback, zmq_socket_ctx
+from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
+                        zmq_socket_ctx)
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
-from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
+from vllm.v1.core.scheduler import Scheduler as V1Scheduler
+from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
@@ -65,6 +67,22 @@ def __init__(
         self.structured_output_manager = StructuredOutputManager(vllm_config)
 
         # Setup scheduler.
+        # scheduler_cls may be a class object or a dotted qualified name.
+        scheduler_cls = vllm_config.scheduler_config.scheduler_cls
+        if isinstance(scheduler_cls, str):
+            Scheduler = resolve_obj_by_qualname(scheduler_cls)
+        else:
+            Scheduler = scheduler_cls
+
+        # Warn only for a non-default scheduler: V1 engine args set
+        # scheduler_cls to the qualified name of the built-in V1 scheduler
+        # by default, so a string value alone does not imply a custom class.
+        if Scheduler is not V1Scheduler:
+            logger.warning(
+                "Using configured V1 scheduler class %s. "
+                "This scheduler interface is not public and "
+                "compatibility may not be maintained.", scheduler_cls)
+
         self.scheduler = Scheduler(
             scheduler_config=vllm_config.scheduler_config,
             model_config=vllm_config.model_config,