1111import numpy as np
1212from safetensors .torch import load_file , save_file , safe_open
1313import torch
14- from transformers import PretrainedConfig
1514from tqdm .auto import tqdm
1615
16+ from vllm .config import ModelConfig
1717from vllm .logger import init_logger
1818from vllm .model_executor .layers .quantization import (get_quantization_config ,
1919 QuantizationConfig )
@@ -102,25 +102,22 @@ def get_sparse_config(
102102
103103
104104# TODO(woosuk): Move this to other place.
105- def get_quant_config (
106- quantization : str ,
107- model_name_or_path : str ,
108- hf_config : PretrainedConfig ,
109- cache_dir : Optional [str ] = None ,
110- ) -> QuantizationConfig :
111- quant_cls = get_quantization_config (quantization )
105+ def get_quant_config (model_config : ModelConfig ) -> QuantizationConfig :
106+ quant_cls = get_quantization_config (model_config .quantization )
112107 # Read the quantization config from the HF model config, if available.
113- hf_quant_config = getattr (hf_config , "quantization_config" , None )
108+ hf_quant_config = getattr (model_config .hf_config , "quantization_config" ,
109+ None )
114110 if hf_quant_config is not None :
115111 return quant_cls .from_config (hf_quant_config )
116-
112+ model_name_or_path = model_config . model
117113 is_local = os .path .isdir (model_name_or_path )
118114 if not is_local :
119115 # Download the config files.
120- with get_lock (model_name_or_path , cache_dir ):
116+ with get_lock (model_name_or_path , model_config . download_dir ):
121117 hf_folder = snapshot_download (model_name_or_path ,
118+ revision = model_config .revision ,
122119 allow_patterns = "*.json" ,
123- cache_dir = cache_dir ,
120+ cache_dir = model_config . download_dir ,
124121 tqdm_class = Disabledtqdm )
125122 else :
126123 hf_folder = model_name_or_path
@@ -131,10 +128,12 @@ def get_quant_config(
131128 f .endswith (x ) for x in quant_cls .get_config_filenames ())
132129 ]
133130 if len (quant_config_files ) == 0 :
134- raise ValueError (f"Cannot find the config file for { quantization } " )
131+ raise ValueError (
132+ f"Cannot find the config file for { model_config .quantization } " )
135133 if len (quant_config_files ) > 1 :
136- raise ValueError (f"Found multiple config files for { quantization } : "
137- f"{ quant_config_files } " )
134+ raise ValueError (
135+ f"Found multiple config files for { model_config .quantization } : "
136+ f"{ quant_config_files } " )
138137
139138 quant_config_file = quant_config_files [0 ]
140139 with open (quant_config_file , "r" ) as f :
0 commit comments