|
4 | 4 | from fractions import Fraction |
5 | 5 | from typing import TYPE_CHECKING, Any |
6 | 6 |
|
| 7 | +import regex as re |
7 | 8 | import torch |
8 | 9 |
|
9 | 10 | from vllm.logger import init_logger |
@@ -128,11 +129,44 @@ def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": |
128 | 129 |
|
129 | 130 | def get_layer_config(self, layer, layer_name: str): |
130 | 131 | def get_config(name: str, quantized: bool = True): |
131 | | - cfg = self.extra_config.get(name, {}) if self.extra_config else {} |
| 132 | + if not self.extra_config: |
| 133 | + return ( |
| 134 | + self.weight_bits if quantized else 16, |
| 135 | + self.group_size if quantized else -1, |
| 136 | + self.sym if quantized else True, |
| 137 | + ) |
| 138 | + |
| 139 | + # exact match first |
| 140 | + if name in self.extra_config: |
| 141 | + cfg = self.extra_config[name] |
| 142 | + return ( |
| 143 | + cfg.get("bits", self.weight_bits if quantized else 16), |
| 144 | + cfg.get("group_size", self.group_size if quantized else -1), |
| 145 | + cfg.get("sym", self.sym if quantized else True), |
| 146 | + ) |
| 147 | + |
| 148 | + REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\") |
| 149 | + for pattern, cfg in self.extra_config.items(): |
| 150 | + if not isinstance(pattern, str) or not any( |
| 151 | + c in REGEX_SPECIAL_CHARS for c in pattern |
| 152 | + ): |
| 153 | + continue |
| 154 | + |
| 155 | + try: |
| 156 | + if re.search(re.compile(pattern), name) is not None: |
| 157 | + return ( |
| 158 | + cfg.get("bits", self.weight_bits if quantized else 16), |
| 159 | + cfg.get("group_size", self.group_size if quantized else -1), |
| 160 | + cfg.get("sym", self.sym if quantized else True), |
| 161 | + ) |
| 162 | + except re.error: |
| 163 | + # Invalid regex, ignore. |
| 164 | + continue |
| 165 | + |
132 | 166 | return ( |
133 | | - cfg.get("bits", self.weight_bits if quantized else 16), |
134 | | - cfg.get("group_size", self.group_size if quantized else -1), |
135 | | - cfg.get("sym", self.sym if quantized else True), |
| 167 | + self.weight_bits if quantized else 16, |
| 168 | + self.group_size if quantized else -1, |
| 169 | + self.sym if quantized else True, |
136 | 170 | ) |
137 | 171 |
|
138 | 172 | # 1. Exact match from config |
@@ -176,7 +210,7 @@ def get_config(name: str, quantized: bool = True): |
176 | 210 | f"consistent quant config for {sub_names}" |
177 | 211 | ) |
178 | 212 |
|
179 | | - # 5. Fallback |
| 213 | + # 5. Fallback or try a regular expression match |
180 | 214 | return get_config(layer_name, quantized) |
181 | 215 |
|
182 | 216 | def check_quantized(self, weight_bits: int) -> bool: |
|
0 commit comments