Merged
Changes from 7 commits
14 changes: 13 additions & 1 deletion tools/pre_commit/mypy.py
@@ -36,12 +36,15 @@
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
"vllm/v1/core",
"vllm/v1/engine",
]

# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [
"tests",
# v0 related
"vllm/attention",
"vllm/compilation",
"vllm/engine",
@@ -50,7 +53,16 @@
"vllm/model_executor",
"vllm/plugins",
"vllm/worker",
"vllm/v1",
# v1 related
"vllm/v1/attention",
"vllm/v1/executor",
"vllm/v1/kv_offload",
"vllm/v1/metrics",
"vllm/v1/pool",
"vllm/v1/sample",
"vllm/v1/spec_decode",
"vllm/v1/structured_output",
"vllm/v1/worker",
]

# TODO(woosuk): Include the code from Megatron and HuggingFace.
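As context for the FILES / SEPARATE_GROUPS split above, here is a minimal sketch of the kind of invocation the comment implies. The command layout and the skip/silent mapping are assumptions for illustration, not the actual tools/pre_commit/mypy.py logic:

    import subprocess
    import sys

    FILES = ["vllm/v1/core", "vllm/v1/engine"]                 # excerpt from the diff
    SEPARATE_GROUPS = ["vllm/v1/attention", "vllm/v1/worker"]  # excerpt from the diff

    def run_mypy(targets: list[str], follow_imports: str) -> int:
        # --follow-imports controls how mypy treats modules imported from outside
        # `targets`: "skip" stubs them out as Any, while "silent" type-checks them
        # but suppresses their errors, which surfaces more real signatures.
        cmd = ["mypy", f"--follow-imports={follow_imports}", *targets]
        return subprocess.run(cmd).returncode

    # Assumed behavior: FILES are checked as one strict batch, while each
    # SEPARATE_GROUPS entry is checked on its own until its errors are fixed
    # and it can be promoted into FILES.
    rc = run_mypy(FILES, "silent")
    for group in SEPARATE_GROUPS:
        rc |= run_mypy([group], "skip")
    sys.exit(rc)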
9 changes: 4 additions & 5 deletions vllm/config/vllm.py
@@ -77,7 +77,9 @@ class VllmConfig:
default_factory=StructuredOutputsConfig
)
"""Structured outputs configuration."""
observability_config: ObservabilityConfig | None = None
observability_config: ObservabilityConfig = Field(
default_factory=ObservabilityConfig
)
"""Observability configuration."""
quant_config: QuantizationConfig | None = None
"""Quantization configuration."""
@@ -164,10 +166,7 @@ def compute_hash(self) -> str:
vllm_factors.append(self.structured_outputs_config.compute_hash())
else:
vllm_factors.append("None")
if self.observability_config:
vllm_factors.append(self.observability_config.compute_hash())
else:
vllm_factors.append("None")
vllm_factors.append(self.observability_config.compute_hash())
if self.quant_config:
pass # should be captured by model_config.quantization
if self.compilation_config:
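For context on why the compute_hash branch above could be simplified, a hedged sketch of the default_factory pattern. The pydantic-style stubs below are illustrative stand-ins, not vLLM's real classes:

    import hashlib

    from pydantic import BaseModel, Field

    class ObservabilityConfigStub(BaseModel):
        otlp_traces_endpoint: str | None = None

        def compute_hash(self) -> str:
            return hashlib.sha256(str(self.otlp_traces_endpoint).encode()).hexdigest()

    class VllmConfigSketch(BaseModel):
        # default_factory means the field is always populated, never None.
        observability_config: ObservabilityConfigStub = Field(
            default_factory=ObservabilityConfigStub
        )

    cfg = VllmConfigSketch()
    # With a guaranteed instance, the old `if ... else append("None")` branch
    # is no longer needed when collecting hash factors.
    vllm_factors = [cfg.observability_config.compute_hash()]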
1 change: 1 addition & 0 deletions vllm/engine/protocol.py
@@ -72,6 +72,7 @@ def encode(
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model."""
21 changes: 11 additions & 10 deletions vllm/v1/core/sched/scheduler.py
@@ -165,7 +165,7 @@ def __init__(
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
enable_caching=self.cache_config.enable_prefix_caching,
enable_caching=bool(self.cache_config.enable_prefix_caching),
use_eagle=self.use_eagle,
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
@@ -392,7 +392,7 @@ def schedule(self) -> SchedulerOutput:
skipped_waiting_requests.prepend_request(request)
continue

num_external_computed_tokens = 0
num_external_computed_tokens: int | None = 0
load_kv_async = False

# Get already-cached tokens.
@@ -419,8 +419,8 @@
continue

# Total computed tokens (local + external).
num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens
num_computed_tokens = num_new_local_computed_tokens + (
num_external_computed_tokens or 0
)
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
@@ -434,6 +434,7 @@

# KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async:
assert isinstance(num_external_computed_tokens, int)
assert num_external_computed_tokens > 0
num_new_tokens = 0
# Number of tokens to be scheduled.
@@ -503,7 +504,7 @@

new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_tokens + (num_external_computed_tokens or 0),
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
@@ -523,7 +524,7 @@
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
num_external_computed_tokens or 0,
)

# Request was already popped from self.waiting
@@ -916,13 +917,13 @@ def update_from_output(

outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats = (
kv_connector_stats: KVConnectorStats | None = (
kv_connector_output.kv_connector_stats if kv_connector_output else None
)
if kv_connector_stats and self.connector:
stats = self.connector.get_kv_connector_stats()
if stats:
kv_connector_stats = kv_connector_stats.aggregate(stats)
kv_stats = self.connector.get_kv_connector_stats()
if kv_stats:
kv_connector_stats = kv_connector_stats.aggregate(kv_stats)

failed_kv_load_req_ids = None
if kv_connector_output and kv_connector_output.invalid_block_ids:
21 changes: 13 additions & 8 deletions vllm/v1/engine/async_llm.py
@@ -6,7 +6,7 @@
import time
from collections.abc import AsyncGenerator, Iterable, Mapping
from copy import copy
from typing import Any
from typing import Any, cast

import numpy as np
import torch
@@ -131,10 +131,9 @@ def __init__(
self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats
)
if self.observability_config.otlp_traces_endpoint is not None:
tracer = init_tracer(
"vllm.llm_engine", self.observability_config.otlp_traces_endpoint
)
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
tracer = init_tracer("vllm.llm_engine", endpoint)
self.output_processor.tracer = tracer

# EngineCore (starts the engine in background process).
@@ -266,7 +265,9 @@ def shutdown(self):
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()

cancel_task_threadsafe(getattr(self, "output_handler", None))
handler = getattr(self, "output_handler", None)
if handler is not None:
cancel_task_threadsafe(handler)

async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async()
@@ -314,7 +315,10 @@ async def add_request(
priority,
data_parallel_rank,
)
prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
if isinstance(prompt, str):
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))

if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue)
@@ -436,6 +440,7 @@ async def generate(
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
assert isinstance(out, RequestOutput)
yield out

# If the request is disconnected by the client, generate()
@@ -653,7 +658,7 @@ async def get_tokenizer(self) -> AnyTokenizer:
return self.tokenizer

async def is_tracing_enabled(self) -> bool:
return self.observability_config.otlp_traces_endpoint is not None
return self.observability_config.otlp_traces_endpoint is not None # type: ignore
Review comment (Member):

Why does the type need to be ignored here? Should this not always return a bool?

Reply (Member, Author):

The issue is:

    vllm/v1/engine/async_llm.py:661: error: Item "None" of "ObservabilityConfig | None" has no attribute "otlp_traces_endpoint"  [union-attr]

The alternative would be:

    async def is_tracing_enabled(self) -> bool:
        assert self.observability_config is not None
        return self.observability_config.otlp_traces_endpoint is not None

I think this is not very elegant, so I just went with # type: ignore.
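A possible third option, sketched here as an assumption rather than a proposal from the thread: bind the possibly-None config to a local so mypy narrows it, avoiding both the assert and the # type: ignore. The classes below are stand-ins, not vLLM's real ones:

    from dataclasses import dataclass

    @dataclass
    class ObservabilityConfigStub:
        otlp_traces_endpoint: str | None = None

    class EngineSketch:
        def __init__(self, observability_config: ObservabilityConfigStub | None) -> None:
            self.observability_config = observability_config

        async def is_tracing_enabled(self) -> bool:
            # The local binding gives mypy a value it can narrow,
            # so no ignore comment is required.
            config = self.observability_config
            return config is not None and config.otlp_traces_endpoint is not None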


async def do_log_stats(self) -> None:
if self.logger_manager:
1 change: 1 addition & 0 deletions vllm/v1/engine/core.py
@@ -1080,6 +1080,7 @@ def _init_data_parallel(self, vllm_config: VllmConfig):
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

assert dp_size > 1
assert local_dp_rank is not None
assert 0 <= local_dp_rank <= dp_rank < dp_size

if vllm_config.kv_transfer_config is not None:
14 changes: 7 additions & 7 deletions vllm/v1/engine/core_client.py
@@ -385,10 +385,11 @@ def close_sockets_and_tasks():
with contextlib.suppress(Exception):
task.cancel()

if in_loop(loop):
close_sockets_and_tasks()
elif loop and not loop.is_closed():
loop.call_soon_threadsafe(close_sockets_and_tasks)
if loop is not None:
if in_loop(loop):
close_sockets_and_tasks()
elif not loop.is_closed():
loop.call_soon_threadsafe(close_sockets_and_tasks)
else:
# Loop has been closed, try to clean up directly.
del tasks
@@ -1044,6 +1045,7 @@ def _ensure_stats_update_task(self):
return

assert self.stats_update_address is not None
stats_addr: str = self.stats_update_address
assert len(self.engine_ranks_managed) > 0
# NOTE: running and waiting counts are all global from
# the Coordinator include all global EngineCores. This
@@ -1054,9 +1056,7 @@

async def run_engine_stats_update_task():
with (
make_zmq_socket(
self.ctx, self.stats_update_address, zmq.XSUB, linger=0
) as socket,
make_zmq_socket(self.ctx, stats_addr, zmq.XSUB, linger=0) as socket,
make_zmq_socket(
self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0
) as first_req_rcv_socket,
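Why the stats_addr local above helps, as an illustrative sketch (the class below is a stand-in, not vLLM code): mypy does not carry an "assert self.attr is not None" narrowing into a nested function, since the attribute could change before the closure runs, whereas a typed local keeps the narrowed type:

    class ClientSketch:
        def __init__(self, addr: str | None) -> None:
            self.addr = addr

        def start(self) -> None:
            assert self.addr is not None
            addr: str = self.addr  # narrowed copy; the closure sees a plain str

            def task() -> None:
                # Using self.addr here would still be `str | None` to mypy;
                # the local `addr` is unconditionally `str`.
                print(addr.upper())

            task()

    ClientSketch("tcp://127.0.0.1:5555").start()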
13 changes: 10 additions & 3 deletions vllm/v1/engine/detokenizer.py
@@ -69,14 +69,21 @@ def __init__(self, request: EngineCoreRequest):
# Stop strings
params = request.sampling_params
assert params is not None
self.stop = stop = params.stop
stop_list: list[str]
if params.stop is None:
stop_list = []
elif isinstance(params.stop, str):
stop_list = [params.stop]
else:
stop_list = params.stop
self.stop = stop_list
self.min_tokens = params.min_tokens
self.include_stop_str_in_output = params.include_stop_str_in_output

# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if stop and not self.include_stop_str_in_output:
self.stop_buffer_length = max(len(s) for s in stop) - 1
if self.stop and not self.include_stop_str_in_output:
self.stop_buffer_length = max(len(s) for s in self.stop) - 1
else:
self.stop_buffer_length = 0
self._last_output_text_offset: int = 0
16 changes: 9 additions & 7 deletions vllm/v1/engine/llm_engine.py
@@ -4,7 +4,7 @@
import time
from collections.abc import Callable, Mapping
from copy import copy
from typing import Any
from typing import Any, cast

import torch.nn as nn
from typing_extensions import TypeVar
@@ -112,10 +112,9 @@ def __init__(
self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats
)
if self.observability_config.otlp_traces_endpoint is not None:
tracer = init_tracer(
"vllm.llm_engine", self.observability_config.otlp_traces_endpoint
)
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
tracer = init_tracer("vllm.llm_engine", endpoint)
self.output_processor.tracer = tracer

# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
@@ -259,7 +258,10 @@ def add_request(
trace_headers,
priority,
)
prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
if isinstance(prompt, str):
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))

n = params.n if isinstance(params, SamplingParams) else 1

@@ -285,7 +287,7 @@
# Add the request to EngineCore.
self.engine_core.add_request(child_request)

def step(self) -> list[RequestOutput] | list[PoolingRequestOutput]:
def step(self) -> list[RequestOutput | PoolingRequestOutput]:
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
self.engine_core.execute_dummy_batch()
11 changes: 9 additions & 2 deletions vllm/v1/engine/output_processor.py
@@ -47,7 +47,14 @@ def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
# This ensures that request outputs with different request indexes
# (if n > 1) do not override each other.
self.output.add(output, aggregate=self.aggregate)
if isinstance(self.output, RequestOutput) and isinstance(
output, RequestOutput
):
self.output.add(output, aggregate=self.aggregate)
elif isinstance(self.output, PoolingRequestOutput) and isinstance(
output, PoolingRequestOutput
):
self.output = output

async def get(self) -> RequestOutput | PoolingRequestOutput:
"""Get operation blocks on put event."""
@@ -407,7 +414,7 @@ def process_outputs(
within the loop below.
"""

request_outputs: list[RequestOutput] | list[PoolingRequestOutput] = []
request_outputs: list[RequestOutput | PoolingRequestOutput] = []
reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
4 changes: 2 additions & 2 deletions vllm/v1/engine/parallel_sampling.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from copy import copy
from typing import Optional
from typing import Optional, cast

from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -37,7 +37,7 @@ def __init__(self, request_id: str, sampling_params: SamplingParams) -> None:

self.child_requests = set()
self.output_aggregator = (
[None] * sampling_params.n
[cast(CompletionOutput, None)] * sampling_params.n
if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
else []
)
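A brief note on the cast above, as an illustrative sketch with stand-in names (not vLLM's classes): cast(CompletionOutput, None) pre-sizes the aggregator as list[CompletionOutput] even though the slots start empty, on the assumption that each index is filled exactly once before being read; the conservative alternative is list[CompletionOutput | None] with narrowing on read:

    from dataclasses import dataclass
    from typing import cast

    @dataclass
    class OutputStub:
        text: str

    n = 3
    # Placeholder slots typed as list[OutputStub]; each index must be filled
    # before anything reads it, otherwise the None placeholder would leak.
    slots = [cast(OutputStub, None)] * n
    for i in range(n):
        slots[i] = OutputStub(text=f"output {i}")
    print([s.text for s in slots])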