39 changes: 34 additions & 5 deletions vllm/model_executor/layers/quantization/auto_round.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import re
 from fractions import Fraction
 from typing import TYPE_CHECKING, Any, Optional, Union
 
@@ -120,11 +121,39 @@ def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
     def get_layer_config(self, layer, layer_name: str):
 
         def get_config(name: str, quantized: bool = True):
-            cfg = self.extra_config.get(name, {}) if self.extra_config else {}
+            if not self.extra_config:
+                return (
+                    self.weight_bits if quantized else 16,
+                    self.group_size if quantized else -1,
+                    self.sym if quantized else True,
+                )
+
+            # exact match first
+            if name in self.extra_config:
+                cfg = self.extra_config[name]
+                return (
+                    cfg.get("bits", self.weight_bits if quantized else 16),
+                    cfg.get("group_size", self.group_size if quantized else -1),
+                    cfg.get("sym", self.sym if quantized else True),
+                )
+
+            # If there is no exact match, try a regular expression match
+            for pattern, cfg in self.extra_config.items():
+                try:
+                    if re.fullmatch(pattern, name):
+                        return (
+                            cfg.get("bits", self.weight_bits if quantized else 16),
+                            cfg.get("group_size", self.group_size if quantized else -1),
+                            cfg.get("sym", self.sym if quantized else True),
+                        )
+                except (re.error, TypeError):
+                    # If the regular expression is invalid or the key is not a string, skip
+                    continue
+
             return (
-                cfg.get("bits", self.weight_bits if quantized else 16),
-                cfg.get("group_size", self.group_size if quantized else -1),
-                cfg.get("sym", self.sym if quantized else True),
+                self.weight_bits if quantized else 16,
+                self.group_size if quantized else -1,
+                self.sym if quantized else True,
             )
 
         # 1. Exact match from config
@@ -169,7 +198,7 @@ def get_config(name: str, quantized: bool = True):
f"Fused module '{layer_name}' requires "
f"consistent quant config for {sub_names}")

# 5. Fallback
# 5. Fallback or try a regular expression match
return get_config(layer_name, quantized)

def check_quantized(self, weight_bits: int) -> bool:
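
For reference, here is a minimal standalone sketch of the lookup order the revised `get_config` implements: exact key match first, then `re.fullmatch` against each `extra_config` key, then the global defaults. The layer names, override values, and the `resolve` helper below are hypothetical, chosen only to illustrate the behavior; only the matching order is taken from the diff.

```python
import re

# Hypothetical per-layer overrides: an exact name and a regex pattern, as
# extra_config allows after this change. DEFAULTS stands in for
# self.weight_bits / self.group_size / self.sym.
extra_config = {
    "model.layers.0.self_attn.q_proj": {"bits": 8},                  # exact entry
    r"model\.layers\.\d+\.mlp\..*": {"bits": 4, "group_size": 64},   # regex entry
}
DEFAULTS = {"bits": 4, "group_size": 128, "sym": True}

def resolve(name: str) -> tuple[int, int, bool]:
    """Resolve (bits, group_size, sym): exact match, then regex, then defaults."""
    cfg = None
    if name in extra_config:                        # 1. exact match first
        cfg = extra_config[name]
    else:
        for pattern, candidate in extra_config.items():
            try:
                if re.fullmatch(pattern, name):     # 2. regex fallback
                    cfg = candidate
                    break
            except (re.error, TypeError):           # invalid pattern or non-string key
                continue
    cfg = cfg or {}                                 # 3. fall back to defaults
    return (
        cfg.get("bits", DEFAULTS["bits"]),
        cfg.get("group_size", DEFAULTS["group_size"]),
        cfg.get("sym", DEFAULTS["sym"]),
    )

print(resolve("model.layers.0.self_attn.q_proj"))   # (8, 128, True) -- exact match
print(resolve("model.layers.3.mlp.gate_proj"))      # (4, 64, True)  -- regex match
print(resolve("model.layers.3.self_attn.k_proj"))   # (4, 128, True) -- defaults
```

Note that a regex entry only takes effect when no exact key matches, and any unparsable or non-string key is skipped rather than raising, so existing exact-name configs keep their previous behavior.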