9 | 9 | from vllm.config import CacheConfig |
10 | 10 | from vllm.model_executor.layers.quantization.base_config import ( |
11 | 11 | QuantizationConfig) |
| 12 | +from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod |
12 | 13 |
13 | 14 |
14 | 15 | class Attention(nn.Module): |
@@ -56,15 +57,19 @@ def __init__( |
56 | 57 | quant_method = quant_config.get_quant_method( |
57 | 58 | self) if quant_config else None |
58 | 59 | if quant_method is not None: |
59 | | - if self.kv_cache_dtype == "fp8_e5m2": |
60 | | - raise ValueError("fp8_e5m2 kv-cache is not supported with " |
61 | | - "fp8 checkpoints.") |
62 | | - # When FP8 quantization is enabled, we make a parameter |
63 | | - # "kv_scale" so that it can be loaded from FP8 checkpoint. |
64 | | - # The kv_scale will then be converted back |
65 | | - # to self._kv_scale in a native float32 value after weight loading. |
66 | | - self.quant_method = quant_method |
67 | | - self.quant_method.create_weights(self) |
| 60 | + assert isinstance(quant_method, Fp8KVCacheMethod) |
| 61 | + # TODO (mgoin): kv cache dtype should be specified in the FP8 |
| 62 | + # checkpoint config and become the "auto" behavior |
| 63 | + if "fp8" in self.kv_cache_dtype: |
| 64 | + if self.kv_cache_dtype == "fp8_e5m2": |
| 65 | + raise ValueError("fp8_e5m2 kv-cache is not supported with " |
| 66 | + "fp8 checkpoints.") |
| 67 | + # When FP8 quantization is enabled, we make a parameter |
| 68 | + # "kv_scale" so that it can be loaded from FP8 checkpoint. |
| 69 | + # The kv_scale will then be converted back to self._kv_scale |
| 70 | + # in a native float32 value after weight loading. |
| 71 | + self.quant_method = quant_method |
| 72 | + self.quant_method.create_weights(self) |
68 | 73 |
69 | 74 | # During model initialization, the default dtype is set as the model |
70 | 75 | # weight and activation dtype. |
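
Note on the kv_scale mechanism referenced in the comments above: the listing below is a minimal, self-contained sketch of the register-then-fold-back flow the diff describes, assuming a torch-based attention layer. It is not vLLM's actual Fp8KVCacheMethod implementation; the class name and method bodies here are illustrative assumptions only.

    # Sketch (assumed behavior, not vLLM's real Fp8KVCacheMethod):
    # register a float32 "kv_scale" parameter on the attention layer so an
    # FP8 checkpoint can load a per-layer scaling factor into it, then fold
    # it back into a plain Python float once weight loading has finished.
    import torch
    from torch import nn


    class SketchFp8KVCacheMethod:
        """Hypothetical stand-in for Fp8KVCacheMethod (names are assumptions)."""

        def create_weights(self, layer: nn.Module) -> None:
            # Default scale of 1.0; the checkpoint loader overwrites this if
            # the FP8 checkpoint carries a kv_scale for this layer.
            layer.kv_scale = nn.Parameter(
                torch.tensor(1.0, dtype=torch.float32), requires_grad=False)

        def process_weights_after_loading(self, layer: nn.Module) -> None:
            # Convert the loaded value back to a native float and drop the
            # parameter so it is not treated as a model weight at runtime.
            layer._kv_scale = float(layer.kv_scale)
            del layer.kv_scale

Registering the scale as a non-trainable parameter lets the ordinary weight loader populate it from the checkpoint without special-casing, and folding it back into a plain float afterwards keeps the attention forward path free of extra tensor indirection.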