 import torch
 from transformers import AutoConfig, PretrainedConfig

+from cacheflow.logger import init_logger
+from cacheflow.utils import get_cpu_memory
+
+logger = init_logger(__name__)
+
 _GiB = 1 << 30


@@ -73,11 +78,37 @@ def __init__( |
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GiB
+        self._verify_args()

         # Will be set after profiling.
         self.num_gpu_blocks = None
         self.num_cpu_blocks = None

+    def _verify_args(self) -> None:
+        if self.gpu_memory_utilization > 1.0:
+            raise ValueError(
+                "GPU memory utilization cannot exceed 1.0. Got "
+                f"{self.gpu_memory_utilization}.")
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        total_cpu_memory = get_cpu_memory()
+        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
+        # group are in the same node. However, the GPUs may span multiple nodes.
+        num_gpus_per_node = parallel_config.tensor_parallel_size
+        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
+
+        msg = (
+            f"{cpu_memory_usage / _GiB:.2f} GiB out of "
+            f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
+            "allocated for the swap space.")
+        if cpu_memory_usage > 0.7 * total_cpu_memory:
+            raise ValueError("Too large swap space. " + msg)
+        elif cpu_memory_usage > 0.4 * total_cpu_memory:
+            logger.warning("Possibly too large swap space. " + msg)
+


 class ParallelConfig:

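A minimal sketch of how the new swap-space check behaves, plugging sample numbers into the same arithmetic verify_with_parallel_config() uses; the figures below are hypothetical, and only the formula itself comes from the diff:

    # Hypothetical numbers; only the formula is taken from the diff above.
    _GiB = 1 << 30
    swap_space = 64              # GiB of swap per GPU, as passed to __init__
    tensor_parallel_size = 8     # GPUs assumed to share one node (see FIXME)
    total_cpu_memory = 256 * _GiB

    # Mirrors: cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
    cpu_memory_usage = swap_space * _GiB * tensor_parallel_size

    # 512 GiB requested > 0.7 * 256 GiB available, so the check raises
    # ValueError; between 0.4x and 0.7x it would only log a warning.
    assert cpu_memory_usage > 0.7 * total_cpu_memory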
|