Skip to content
44 changes: 39 additions & 5 deletions vllm/model_executor/layers/quantization/auto_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fractions import Fraction
from typing import TYPE_CHECKING, Any

import regex as re
import torch

from vllm.logger import init_logger
Expand Down Expand Up @@ -128,11 +129,44 @@ def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":

def get_layer_config(self, layer, layer_name: str):
def get_config(name: str, quantized: bool = True):
cfg = self.extra_config.get(name, {}) if self.extra_config else {}
if not self.extra_config:
return (
self.weight_bits if quantized else 16,
self.group_size if quantized else -1,
self.sym if quantized else True,
)

# exact match first
if name in self.extra_config:
cfg = self.extra_config[name]
return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
)

REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\")
for pattern, cfg in self.extra_config.items():
if not isinstance(pattern, str) or not any(
c in REGEX_SPECIAL_CHARS for c in pattern
):
continue

try:
if re.search(re.compile(pattern), name) is not None:
return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
)
except re.error:
# Invalid regex, ignore.
continue

return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
self.weight_bits if quantized else 16,
self.group_size if quantized else -1,
self.sym if quantized else True,
)

# 1. Exact match from config
Expand Down Expand Up @@ -176,7 +210,7 @@ def get_config(name: str, quantized: bool = True):
f"consistent quant config for {sub_names}"
)

# 5. Fallback
# 5. Fallback or try a regular expression match
return get_config(layer_name, quantized)

def check_quantized(self, weight_bits: int) -> bool:
Expand Down