Skip to content

Commit c803006

Browse files
youkaichao and weilong.yu
authored and committed
[platforms] refactor cpu code (vllm-project#10402)
Signed-off-by: youkaichao <[email protected]>
1 parent 052788e commit c803006

File tree

2 files changed

+61
-67
lines changed

2 files changed

+61
-67
lines changed

vllm/executor/cpu_executor.py

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
from functools import partial
33
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
44

5-
import vllm.envs as envs
6-
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
7-
SchedulerConfig)
85
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
96
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
107
ResultHandler, WorkerMonitor)
@@ -13,7 +10,7 @@
1310
from vllm.model_executor.layers.sampler import SamplerOutput
1411
from vllm.prompt_adapter.request import PromptAdapterRequest
1512
from vllm.sequence import ExecuteModelRequest
16-
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port,
13+
from vllm.utils import (get_distributed_init_method, get_open_port,
1714
get_vllm_instance_id, make_async)
1815
from vllm.worker.worker_base import WorkerWrapperBase
1916

@@ -57,13 +54,6 @@ def _init_executor(self) -> None:
5754
os.environ["LOCAL_WORLD_SIZE"] = str(
5855
self.parallel_config.tensor_parallel_size)
5956

60-
self.model_config = _verify_and_get_model_config(self.model_config)
61-
self.cache_config = _verify_and_get_cache_config(self.cache_config)
62-
self.scheduler_config = _verify_and_get_scheduler_config(
63-
self.scheduler_config)
64-
self.parallel_config = _verify_and_get_parallel_config(
65-
self.parallel_config)
66-
6757
# Multiprocessing-based executor does not support multi-node setting.
6858
# Since it only works for single node, we can use the loopback address
6959
# 127.0.0.1 for communication.
@@ -313,62 +303,6 @@ async def check_health_async(self) -> None:
313303
self.check_health()
314304

315305

316-
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
317-
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
318-
# If the feature combo become valid
319-
if not config.enforce_eager:
320-
logger.warning(
321-
"CUDA graph is not supported on CPU, fallback to the eager "
322-
"mode.")
323-
config.enforce_eager = True
324-
return config
325-
326-
327-
def _verify_and_get_scheduler_config(
328-
config: SchedulerConfig) -> SchedulerConfig:
329-
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
330-
# If the feature combo become valid
331-
if config.chunked_prefill_enabled:
332-
logger.warning("Chunked prefill is not supported on CPU, disable it.")
333-
config.chunked_prefill_enabled = False
334-
335-
return config
336-
337-
338-
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
339-
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
340-
# If the feature combo become valid
341-
if config.enable_prefix_caching:
342-
logger.warning("Prefix caching is not supported on CPU, disable it.")
343-
config.enable_prefix_caching = False
344-
345-
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
346-
347-
if kv_cache_space >= 0:
348-
if kv_cache_space == 0:
349-
config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
350-
logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
351-
"for CPU backend is not set, using 4 by default.")
352-
else:
353-
config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
354-
else:
355-
raise RuntimeError(
356-
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
357-
f" {kv_cache_space}, expect a positive integer value.")
358-
359-
return config
360-
361-
362-
def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
363-
if (config.distributed_executor_backend is not None
364-
and config.distributed_executor_backend != "mp"):
365-
logger.warning(
366-
"%s is not supported on CPU, fallback to mp distributed executor "
367-
"backend.", config.distributed_executor_backend)
368-
config.distributed_executor_backend = "mp"
369-
return config
370-
371-
372306
def _driver_method_invoker(driver, method: str, *args, **kwargs):
373307
return getattr(driver, method)(*args, **kwargs)
374308

vllm/platforms/cpu.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
1+
from typing import TYPE_CHECKING
2+
13
import psutil
24
import torch
35

6+
from vllm.logger import init_logger
7+
48
from .interface import Platform, PlatformEnum
59

10+
if TYPE_CHECKING:
11+
from vllm.config import VllmConfig
12+
else:
13+
VllmConfig = None
14+
15+
logger = init_logger(__name__)
16+
617

718
class CpuPlatform(Platform):
819
_enum = PlatformEnum.CPU
@@ -18,3 +29,52 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
1829
@classmethod
1930
def inference_mode(cls):
2031
return torch.no_grad()
32+
33+
@classmethod
34+
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
35+
import vllm.envs as envs
36+
from vllm.utils import GiB_bytes
37+
model_config = vllm_config.model_config
38+
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
39+
# If the feature combo become valid
40+
if not model_config.enforce_eager:
41+
logger.warning(
42+
"CUDA graph is not supported on CPU, fallback to the eager "
43+
"mode.")
44+
model_config.enforce_eager = True
45+
46+
cache_config = vllm_config.cache_config
47+
48+
if cache_config.enable_prefix_caching:
49+
logger.warning(
50+
"Prefix caching is not supported on CPU, disable it.")
51+
cache_config.enable_prefix_caching = False
52+
53+
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
54+
55+
if kv_cache_space >= 0:
56+
if kv_cache_space == 0:
57+
cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
58+
logger.warning(
59+
"Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
60+
"for CPU backend is not set, using 4 by default.")
61+
else:
62+
cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa
63+
else:
64+
raise RuntimeError(
65+
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
66+
f" {kv_cache_space}, expect a positive integer value.")
67+
68+
scheduler_config = vllm_config.scheduler_config
69+
if scheduler_config.chunked_prefill_enabled:
70+
logger.warning(
71+
"Chunked prefill is not supported on CPU, disable it.")
72+
scheduler_config.chunked_prefill_enabled = False
73+
74+
parallel_config = vllm_config.parallel_config
75+
if (parallel_config.distributed_executor_backend is not None
76+
and parallel_config.distributed_executor_backend != "mp"):
77+
logger.warning(("%s is not supported on CPU, fallback to mp "
78+
"distributed executor backend."),
79+
parallel_config.distributed_executor_backend)
80+
parallel_config.distributed_executor_backend = "mp"

0 commit comments

Comments (0)