Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions tests/plugins_tests/test_scheduler_plugins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

from unittest import mock

import pytest

from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine


class DummyScheduler(Scheduler):
Expand All @@ -9,13 +18,14 @@ def schedule(self):
raise Exception("Exception raised by DummyScheduler")


def test_scheduler_plugins():
import pytest
class DummyV1Scheduler(V1Scheduler):
    """V1-engine counterpart of DummyScheduler: always fails to schedule."""

    def schedule(self):
        # Intentionally blow up so the test can verify that the custom
        # scheduler class was actually plugged into the engine.
        raise Exception("Exception raised by DummyScheduler")


def test_scheduler_plugins(monkeypatch):
monkeypatch.setenv("VLLM_USE_V1", "0")
with pytest.raises(Exception) as exception_info:

engine_args = EngineArgs(
Expand All @@ -31,3 +41,30 @@ def test_scheduler_plugins():
engine.step()

assert str(exception_info.value) == "Exception raised by DummyScheduler"


def test_scheduler_plugins_v1(monkeypatch):
    """Verify that a user-supplied scheduler class is honored by the V1 engine."""
    monkeypatch.setenv("VLLM_USE_V1", "1")

    # The V1 engine adds a layer of indirection: the worker process raises
    # the dummy scheduler's exception, and the client in this process then
    # tries to kill the worker's process tree. Patch that teardown hook so
    # the failure surfaces here as a catchable exception instead.
    with mock.patch(
            "vllm.v1.engine.core_client.kill_process_tree") as mock_kill:
        mock_kill.side_effect = Exception("kill_process_tree was called")

        with pytest.raises(Exception) as exception_info:
            engine_args = EngineArgs(
                model="facebook/opt-125m",
                enforce_eager=True,  # reduce test time
                scheduler_cls=DummyV1Scheduler,
            )
            engine = V1LLMEngine.from_engine_args(engine_args=engine_args)

            sampling_params = SamplingParams(max_tokens=1)
            engine.add_request("0", "foo", sampling_params)
            engine.step()

        assert str(exception_info.value) == "kill_process_tree was called"
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,11 @@ def _override_v1_engine_args(self, usage_context: UsageContext) -> None:

# V1 always uses chunked prefills.
self.enable_chunked_prefill = True
# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls:
self.scheduler_cls = "vllm.v1.core.scheduler.Scheduler"

# When no user override, set the default values based on the usage
# context.
# Use different default values for different hardware.
Expand Down
10 changes: 8 additions & 2 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
Expand Down Expand Up @@ -65,6 +66,11 @@ def __init__(
self.structured_output_manager = StructuredOutputManager(vllm_config)

# Setup scheduler.
if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = vllm_config.scheduler_config.scheduler_cls
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config,
Expand Down