5 files changed (+35 -11 lines)

examples/offline_inference

@@ -22,7 +22,8 @@ def main():
     # In real workloads, `enforace_eager` should be `False`.
     llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
               max_num_batched_tokens=64,
-              max_num_seqs=4)
+              max_num_seqs=4,
+              max_model_len=128)
     outputs = llm.generate(prompts, sampling_params)
     print("-" * 50)
     for output, answer in zip(outputs, answers):
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250408
-torchvision==0.22.0.dev20250408
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250430
+torchvision==0.22.0.dev20250430
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
@@ -76,9 +76,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel

         cache_config = vllm_config.cache_config
+        # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
-
         compilation_config = vllm_config.compilation_config

         # TPU only supports DYNAMO_ONCE compilation level
@@ -101,16 +101,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if envs.VLLM_USE_V1:
             from vllm.v1.attention.backends.pallas import (
                 PallasAttentionBackend)
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)
             min_page_size = PallasAttentionBackend.get_min_page_size(
                 vllm_config)
-            if min_page_size > vllm_config.cache_config.block_size:
+            if min_page_size > cache_config.block_size:
                 logger.warning(
                     "Increase the page size from %s to %s to make sure there's"
                     "no SMEM OOM",
-                    vllm_config.cache_config.block_size,
+                    cache_config.block_size,
                     min_page_size,
                 )
-                vllm_config.cache_config.block_size = min_page_size
+                cache_config.block_size = min_page_size

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
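Net effect of the two hunks above: on v1, the KV-cache block size is first initialized from PallasAttentionBackend.get_page_size (driven by max_model_len) and only raised afterwards if get_min_page_size reports a larger SMEM-safe minimum. A minimal standalone sketch of that ordering, where choose_block_size and the sample numbers are illustrative stand-ins rather than code from this PR:

    def choose_block_size(page_size: int, min_page_size: int) -> int:
        # 1) start from the heuristic page size derived from max_model_len
        block_size = page_size
        # 2) raise it only if the backend's SMEM-driven minimum demands more
        if min_page_size > block_size:
            block_size = min_page_size
        return block_size

    print(choose_block_size(page_size=16, min_page_size=32))   # 32 (bumped up)
    print(choose_block_size(page_size=128, min_page_size=16))  # 128 (kept)
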
@@ -708,6 +708,13 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)


+def next_power_of_2(n) -> int:
+    """The next power of 2 (inclusive)"""
+    if n < 1:
+        return 1
+    return 1 << (n - 1).bit_length()
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y

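A few sample values for next_power_of_2 may help; it is "inclusive" in the sense that an exact power of two is returned unchanged. These checks are illustrative only and not part of the diff:

    assert next_power_of_2(0) == 1   # anything below 1 floors at 1
    assert next_power_of_2(1) == 1
    assert next_power_of_2(5) == 8
    assert next_power_of_2(8) == 8   # exact powers of two map to themselves
    assert next_power_of_2(9) == 16
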
@@ -12,7 +12,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils import cdiv, next_power_of_2

 logger = init_logger(__name__)

@@ -65,6 +65,20 @@ def get_min_page_size(vllm_config: VllmConfig) -> int:
         min_page_size = 1 << (min_page_size - 1).bit_length()
         return min_page_size

+    # TPU has limited SREGs (scalar registers). If the page size is too
+    # small, SREGs can spill easily, which leads to bad performance. The
+    # strategy here is to split max-model-len into 16 pages, which makes
+    # spills less likely, while keeping the page size within [16, 256].
+    @staticmethod
+    def get_page_size(vllm_config: VllmConfig) -> int:
+        page_size = next_power_of_2(
+            vllm_config.model_config.max_model_len) // 16
+        if page_size <= 16:
+            return 16
+        if page_size >= 256:
+            return 256
+        return page_size
+

 @dataclass
 class PallasMetadata:
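To make the clamping in get_page_size concrete, here is a hand-replay for a few model lengths (illustrative only; the inline formula mirrors the method above rather than calling it through a VllmConfig):

    for max_model_len in (512, 2048, 16384):
        page_size = max(16, min(256, (1 << (max_model_len - 1).bit_length()) // 16))
        print(max_model_len, page_size)  # 512 -> 32, 2048 -> 128, 16384 -> 256 (capped)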