Skip to content

Commit bdcd016

Browse files
rkooo567joerunde
authored andcommitted
[mypy] Add mypy type annotation part 1 (vllm-project#4006)
1 parent 63c2316 commit bdcd016

25 files changed

Lines changed: 171 additions & 72 deletions

.github/workflows/mypy.yaml

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
name: mypy
2+
3+
on:
4+
# Trigger the workflow on push or pull request,
5+
# but only for the main branch
6+
push:
7+
branches:
8+
- main
9+
pull_request:
10+
branches:
11+
- main
12+
13+
jobs:
14+
ruff:
15+
runs-on: ubuntu-latest
16+
strategy:
17+
matrix:
18+
python-version: ["3.8"]
19+
steps:
20+
- uses: actions/checkout@v2
21+
- name: Set up Python ${{ matrix.python-version }}
22+
uses: actions/setup-python@v2
23+
with:
24+
python-version: ${{ matrix.python-version }}
25+
- name: Install dependencies
26+
run: |
27+
python -m pip install --upgrade pip
28+
pip install mypy==1.9.0
29+
pip install types-setuptools
30+
pip install types-PyYAML
31+
pip install types-requests
32+
pip install types-setuptools
33+
- name: Mypy
34+
run: |
35+
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
36+
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
37+
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
38+
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
39+
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
40+
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
41+
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
42+
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
43+
44+
# TODO(sang): Follow up
45+
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
46+
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
47+
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
48+
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
49+
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
50+

format.sh

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,23 @@ fi
9393
echo 'vLLM yapf: Done'
9494

9595
# Run mypy
96-
# TODO(zhuohan): Enable mypy
97-
# echo 'vLLM mypy:'
98-
# mypy
96+
echo 'vLLM mypy:'
97+
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
98+
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
99+
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
100+
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
101+
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
102+
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
103+
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
104+
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
105+
106+
# TODO(sang): Follow up
107+
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
108+
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
109+
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
110+
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
111+
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
112+
99113

100114
CODESPELL_EXCLUDES=(
101115
'--skip' '*docs/source/_build/**'
@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
228242

229243
exit 1
230244
fi
231-
232-

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,13 @@ ignore = [
4646
python_version = "3.8"
4747

4848
ignore_missing_imports = true
49+
check_untyped_defs = true
4950

5051
files = "vllm"
5152
# TODO(woosuk): Include the code from Megatron and HuggingFace.
52-
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
53+
exclude = [
54+
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
55+
]
5356

5457

5558
[tool.codespell]

requirements-common.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ uvicorn[standard]
1111
pydantic >= 2.0 # Required for OpenAI server.
1212
prometheus_client >= 0.18.0
1313
tiktoken == 0.6.0 # Required for DBRX tokenizer
14-
outlines == 0.0.34 # Requires torch >= 2.1.0
14+
outlines == 0.0.34 # Requires torch >= 2.1.0
15+
typing_extensions

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ codespell==2.2.6
77
isort==5.13.2
88

99
# type checking
10-
mypy==0.991
10+
mypy==1.9.0
1111
types-PyYAML
1212
types-requests
1313
types-setuptools

vllm/config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import os
44
from dataclasses import dataclass, fields
5-
from typing import TYPE_CHECKING, ClassVar, Optional, Union
5+
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
66

77
import torch
88
from packaging.version import Version
@@ -141,7 +141,7 @@ def _verify_load_format(self) -> None:
141141
supported_load_format = [
142142
"auto", "pt", "safetensors", "npcache", "dummy"
143143
]
144-
rocm_not_supported_load_format = []
144+
rocm_not_supported_load_format: List[str] = []
145145
if load_format not in supported_load_format:
146146
raise ValueError(
147147
f"Unknown load format: {self.load_format}. Must be one of "
@@ -679,6 +679,9 @@ def maybe_create_spec_config(
679679
"num_speculative_tokens to be provided, but found "
680680
f"{speculative_model=} and {num_speculative_tokens=}.")
681681

682+
assert (speculative_model is not None
683+
and num_speculative_tokens is not None)
684+
682685
# TODO: The user should be able to specify revision/quantization/max
683686
# model len for the draft model. It is not currently supported.
684687
draft_revision = None
@@ -993,7 +996,7 @@ def _get_and_verify_max_len(
993996
derived_max_model_len *= scaling_factor
994997

995998
if max_model_len is None:
996-
max_model_len = derived_max_model_len
999+
max_model_len = int(derived_max_model_len)
9971000
elif max_model_len > derived_max_model_len:
9981001
# Some models might have a separate key for specifying model_max_length
9991002
# that will be bigger than derived_max_model_len. We compare user input

vllm/core/block_manager_v1.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""A block manager that manages token blocks."""
22
from abc import ABC, abstractmethod
3+
from collections.abc import Sequence as GenericSequence
34
from itertools import count, takewhile
45
from os.path import commonprefix
56
from typing import Dict, List, Optional, Set
@@ -231,10 +232,10 @@ def __init__(
231232

232233
if self.enable_caching:
233234
logger.info("Automatic prefix caching is enabled.")
234-
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
235-
num_gpu_blocks)
236-
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
237-
num_cpu_blocks)
235+
self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
236+
Device.GPU, block_size, num_gpu_blocks)
237+
self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
238+
Device.CPU, block_size, num_cpu_blocks)
238239
else:
239240
self.gpu_allocator = UncachedBlockAllocator(
240241
Device.GPU, block_size, num_gpu_blocks)
@@ -588,7 +589,8 @@ def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
588589
for b in takewhile(lambda b: b.computed, block_table[:-1])
589590
]
590591

591-
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
592+
def get_common_computed_block_ids(
593+
self, seqs: List[Sequence]) -> GenericSequence[int]:
592594
"""Return the block ids that are common for a given sequence group.
593595
594596
Used in prefill (can skip prefill of some blocks).

vllm/core/block_manager_v2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""A block manager that manages token blocks."""
2+
from collections.abc import Sequence as GenericSequence
23
from typing import Dict, List, Optional
34

45
from vllm.core.block.block_table import BlockTable
@@ -205,7 +206,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup):
205206
# as computed.
206207
self.block_allocator.mark_blocks_as_computed()
207208

208-
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
209+
def get_common_computed_block_ids(
210+
self, seqs: List[Sequence]) -> GenericSequence[int]:
209211
"""Determine which blocks for which we skip prefill.
210212
211213
With prefix caching we can skip prefill for previously-generated blocks.

vllm/core/interfaces.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import enum
22
from abc import ABC, abstractmethod
3+
from collections.abc import Sequence as GenericSequence
34
from typing import Dict, List
45

56
from vllm.sequence import Sequence, SequenceGroup
@@ -103,7 +104,8 @@ def access_all_blocks_in_seq(
103104
pass
104105

105106
@abstractmethod
106-
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
107+
def get_common_computed_block_ids(
108+
self, seqs: List[Sequence]) -> GenericSequence[int]:
107109
pass
108110

109111
@abstractmethod

vllm/core/scheduler.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class SchedulingBudget:
4242
"""
4343
token_budget: int
4444
max_num_seqs: int
45-
_requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
46-
_requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
45+
_requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
46+
_requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
4747
_num_batched_tokens: int = 0
4848
_num_curr_seqs: int = 0
4949

@@ -133,7 +133,7 @@ def is_empty(self) -> bool:
133133
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
134134
and not self.blocks_to_swap_out and not self.blocks_to_copy)
135135

136-
def _sort_by_lora_ids(self) -> bool:
136+
def _sort_by_lora_ids(self):
137137
self.scheduled_seq_groups = sorted(
138138
self.scheduled_seq_groups,
139139
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@@ -337,7 +337,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
337337
self.free_seq(seq)
338338

339339
def has_unfinished_seqs(self) -> bool:
340-
return self.waiting or self.running or self.swapped
340+
return len(self.waiting) != 0 or len(self.running) != 0 or len(
341+
self.swapped) != 0
341342

342343
def get_num_unfinished_seq_groups(self) -> int:
343344
return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -404,7 +405,7 @@ def _schedule_running(
404405
budget.subtract_num_seqs(seq_group.request_id,
405406
num_running_seqs)
406407
if curr_loras is not None and seq_group.lora_int_id > 0:
407-
curr_loras.pop(seq_group.lora_int_id)
408+
curr_loras.remove(seq_group.lora_int_id)
408409

409410
if running_queue:
410411
# Preempt the lowest-priority sequence groups.
@@ -496,7 +497,7 @@ def _schedule_swapped(
496497
now = time.time()
497498
swapped_queue = policy.sort_by_priority(now, swapped_queue)
498499

499-
leftover_swapped = deque()
500+
leftover_swapped: Deque[SequenceGroup] = deque()
500501
while swapped_queue:
501502
seq_group = swapped_queue[0]
502503

@@ -507,7 +508,9 @@ def _schedule_swapped(
507508
lora_int_id = 0
508509
if self.lora_enabled:
509510
lora_int_id = seq_group.lora_int_id
510-
if (lora_int_id > 0 and lora_int_id not in curr_loras
511+
assert curr_loras is not None
512+
assert self.lora_config is not None
513+
if (lora_int_id > 0 and (lora_int_id not in curr_loras)
511514
and len(curr_loras) >= self.lora_config.max_loras):
512515
# We don't have a space for another LoRA, so
513516
# we ignore this request for now.
@@ -593,7 +596,7 @@ def _schedule_prefills(
593596
# Copy the queue so that the input queue is not modified.
594597
waiting_queue = deque([s for s in waiting_queue])
595598

596-
leftover_waiting_sequences = deque()
599+
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
597600
while self._passed_delay(time.time()) and waiting_queue:
598601
seq_group = waiting_queue[0]
599602

@@ -635,6 +638,8 @@ def _schedule_prefills(
635638
lora_int_id = 0
636639
if self.lora_enabled:
637640
lora_int_id = seq_group.lora_int_id
641+
assert curr_loras is not None
642+
assert self.lora_config is not None
638643
if (self.lora_enabled and lora_int_id > 0
639644
and lora_int_id not in curr_loras
640645
and len(curr_loras) >= self.lora_config.max_loras):
@@ -780,7 +785,7 @@ def _schedule_chunked_prefill(self):
780785
token_budget=self.scheduler_config.max_num_batched_tokens,
781786
max_num_seqs=self.scheduler_config.max_num_seqs,
782787
)
783-
curr_loras = set()
788+
curr_loras: Set[int] = set()
784789

785790
remaining_waiting, prefills = (self.waiting,
786791
SchedulerPrefillOutputs.create_empty())
@@ -1087,7 +1092,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
10871092

10881093
def _get_num_new_tokens(self, seq_group: SequenceGroup,
10891094
status: SequenceStatus, enable_chunking: bool,
1090-
budget: SchedulingBudget) -> Tuple[int, bool]:
1095+
budget: SchedulingBudget) -> int:
10911096
"""Get the next new tokens to compute for a given sequence group
10921097
that's in a given `status`.
10931098

0 commit comments

Comments
 (0)