
Commit aedba6d

Print warnings/errors for large swap space (#123)

1 parent a283ec2 commit aedba6d

File tree

3 files changed: +34 -0 lines changed

cacheflow/config.py

Lines changed: 31 additions & 0 deletions

@@ -3,6 +3,11 @@
 import torch
 from transformers import AutoConfig, PretrainedConfig
 
+from cacheflow.logger import init_logger
+from cacheflow.utils import get_cpu_memory
+
+logger = init_logger(__name__)
+
 _GiB = 1 << 30

@@ -73,11 +78,37 @@ def __init__(
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GiB
+        self._verify_args()
 
         # Will be set after profiling.
         self.num_gpu_blocks = None
         self.num_cpu_blocks = None
 
+    def _verify_args(self) -> None:
+        if self.gpu_memory_utilization > 1.0:
+            raise ValueError(
+                "GPU memory utilization must be less than 1.0. Got "
+                f"{self.gpu_memory_utilization}.")
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        total_cpu_memory = get_cpu_memory()
+        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
+        # group are in the same node. However, the GPUs may span multiple nodes.
+        num_gpus_per_node = parallel_config.tensor_parallel_size
+        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
+
+        msg = (
+            f"{cpu_memory_usage / _GiB:.2f} GiB out of "
+            f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
+            "allocated for the swap space.")
+        if cpu_memory_usage > 0.7 * total_cpu_memory:
+            raise ValueError("Too large swap space. " + msg)
+        elif cpu_memory_usage > 0.4 * total_cpu_memory:
+            logger.warn("Possibly too large swap space. " + msg)
+
 
 class ParallelConfig:
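The policy above is simple: swap space is host memory reserved per GPU, so a node running tensor_parallel_size workers consumes swap_space * tensor_parallel_size bytes of CPU RAM. Above 70% of total CPU memory the config raises; above 40% it only warns. A minimal sketch of how the two cutoffs behave — the function check_swap and all numbers are hypothetical, not part of the commit:

# Sketch of the 40%/70% policy in verify_with_parallel_config.
# check_swap and the example numbers are hypothetical illustrations.
_GiB = 1 << 30

def check_swap(swap_space_gib: int, num_gpus_per_node: int,
               total_cpu_memory: int) -> str:
    cpu_memory_usage = swap_space_gib * _GiB * num_gpus_per_node
    if cpu_memory_usage > 0.7 * total_cpu_memory:
        return "error"    # the commit raises ValueError here
    elif cpu_memory_usage > 0.4 * total_cpu_memory:
        return "warning"  # the commit logs a warning here
    return "ok"

# On a node with 64 GiB of CPU memory and 2 tensor-parallel GPUs:
total = 64 * _GiB
assert check_swap(10, 2, total) == "ok"       # 20 GiB <= 25.6 GiB (40%)
assert check_swap(20, 2, total) == "warning"  # 40 GiB in (25.6, 44.8] GiB
assert check_swap(25, 2, total) == "error"    # 50 GiB > 44.8 GiB (70%)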

cacheflow/server/llm_server.py

Lines changed: 1 addition & 0 deletions

@@ -84,6 +84,7 @@ def __init__(
 
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
+        self.cache_config.verify_with_parallel_config(self.parallel_config)
 
     def _init_cache(self) -> None:
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
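For context, a minimal sketch of the startup validation flow this one-line change completes. Only the attribute and method names come from the diff; the rest of the constructor is assumed for illustration:

# Sketch only: the constructor body is assumed, not taken from the diff.
class LLMServer:
    def __init__(self, model_config, cache_config, parallel_config):
        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self._verify_args()  # fail fast, before any cache memory is allocated

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)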

cacheflow/utils.py

Lines changed: 2 additions & 0 deletions

@@ -24,8 +24,10 @@ def reset(self) -> None:
 
 
 def get_gpu_memory(gpu: int = 0) -> int:
+    """Returns the total memory of the GPU in bytes."""
     return torch.cuda.get_device_properties(gpu).total_memory
 
 
 def get_cpu_memory() -> int:
+    """Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total
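Both helpers now document that they return bytes. A hypothetical usage example (assumes psutil is installed and a CUDA device is visible to torch):

# Hypothetical usage of the two documented helpers; both return bytes.
from cacheflow.utils import get_cpu_memory, get_gpu_memory

_GiB = 1 << 30
print(f"Node CPU memory: {get_cpu_memory() / _GiB:.2f} GiB")
print(f"GPU 0 memory:    {get_gpu_memory(0) / _GiB:.2f} GiB")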
