@@ -7,28 +7,20 @@
 import modules.shared as shared
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
-from llama import load_quant
 
 
 # 4-bit LLaMA
-def load_quantized_LLaMA(model_name):
-    if shared.args.load_in_4bit:
-        bits = 4
+def load_quant(model_name, model_type):
+    if model_type == 'llama':
+        from llama import load_quant
+    elif model_type == 'opt':
+        from opt import load_quant
     else:
-        bits = shared.args.gptq_bits
+        print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
+        exit()
 
     path_to_model = Path(f'models/{model_name}')
-    pt_model = ''
-    if path_to_model.name.lower().startswith('llama-7b'):
-        pt_model = f'llama-7b-{bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-13b'):
-        pt_model = f'llama-13b-{bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-30b'):
-        pt_model = f'llama-30b-{bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-65b'):
-        pt_model = f'llama-65b-{bits}bit.pt'
-    else:
-        pt_model = f'{model_name}-{bits}bit.pt'
+    pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
 
     # Try to find the .pt both in models/ and in the subfolder
     pt_path = None
@@ -40,7 +32,7 @@ def load_quantized_LLaMA(model_name):
         print(f"Could not find {pt_model}, exiting...")
         exit()
 
-    model = load_quant(path_to_model, str(pt_path), bits)
+    model = load_quant(path_to_model, str(pt_path), shared.args.gptq_bits)
 
     # Multiple GPUs or GPU+CPU
     if shared.args.gpu_memory:
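
Worth noting for reviewers: inside the new load_quant(), the function-local "from llama import load_quant" / "from opt import load_quant" rebinds the name load_quant within the function body, so the later call load_quant(path_to_model, str(pt_path), shared.args.gptq_bits) dispatches to the backend loader from GPTQ-for-LLaMa rather than recursing into the wrapper. A minimal sketch of a call site, assuming this file is importable as modules.quant_loader and that the model type comes from a gptq_model_type command-line flag (both the module path and the flag name are assumptions, not shown in this diff):

    # Hypothetical call site, for illustration only: the module path and
    # the gptq_model_type flag are assumptions, not part of this diff.
    import modules.shared as shared
    from modules.quant_loader import load_quant

    if shared.args.gptq_bits > 0:
        # model_type must be 'llama' or 'opt'; anything else prints an
        # error and exits inside load_quant().
        model = load_quant(shared.model_name, shared.args.gptq_model_type)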