@@ -32,67 +32,20 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
+
         # quick sanity check (make sure alignment)
         assert self.in_features % self.group_size == 0
         assert out_features % (32 // self.w_bit) == 0
 
-        self.register_buffer('qweight', torch.zeros((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
-        self.register_buffer('qzeros', torch.zeros((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
-        self.register_buffer('scales', torch.zeros((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
+        self.register_buffer('qweight', torch.empty((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
+        self.register_buffer('qzeros', torch.empty((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
+        self.register_buffer('scales', torch.empty((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
+
         if bias:
-            self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
+            self.register_buffer('bias', torch.empty((out_features), dtype=torch.float16, device=dev))
         else:
             self.bias = None
 
-    @classmethod
-    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
-        awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
-        if init_only:  # just prepare for loading sd
-            return awq_linear
-
-        # need scales and zeros info for real quantization
-        assert scales is not None and zeros is not None
-        scale_zeros = zeros * scales
-
-        awq_linear.scales = scales.clone().half()
-        if linear.bias is not None:
-            awq_linear.bias = linear.bias.clone().half()
-
-        pack_num = 32 // awq_linear.w_bit
-
-        intweight = []
-        for idx in range(awq_linear.in_features):
-            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[idx // group_size]) / awq_linear.scales[idx // group_size]).to(torch.int)[:, None])
-        intweight = torch.cat(intweight, dim=1)
-        intweight = intweight.t().contiguous()
-        intweight = intweight.to(dtype=torch.int32)
-        qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
-
-        for col in range(intweight.shape[1] // pack_num):
-            if awq_linear.w_bit == 4:
-                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
-            else:
-                raise NotImplementedError("Only 4-bit are supported for now.")
-            for i in range(pack_num):
-                qweight_col = intweight[:, col * pack_num + order_map[i]]
-                qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
-        awq_linear.qweight = qweight
-
-        zeros = zeros.to(dtype=torch.int32)
-        qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
-
-        for col in range(zeros.shape[1] // pack_num):
-            if awq_linear.w_bit == 4:
-                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
-            else:
-                raise NotImplementedError("Only 4-bit are supported for now.")
-            for i in range(pack_num):
-                qzero_col = zeros[:, col * pack_num + order_map[i]]
-                qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
-        awq_linear.qzeros = qzeros
-
-        return awq_linear
-
     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features, )
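
For reference, the removed `from_linear` above packs eight 4-bit values into each `int32` column of `qweight`/`qzeros`, placing source column `order_map[i]` into nibble `i` with `order_map = [0, 2, 4, 6, 1, 3, 5, 7]`. The sketch below is a minimal, standalone reproduction of that bit layout for illustration only; the helper names `pack_int4`/`unpack_int4` are assumptions and not part of the AWQ codebase.

```python
import torch

# Interleaved column order used by the 4-bit packing loop in from_linear.
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]


def pack_int4(intweight: torch.Tensor) -> torch.Tensor:
    """Pack a (rows, cols) int32 tensor of values in [0, 15] into (rows, cols // 8)."""
    assert intweight.shape[1] % 8 == 0
    qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 8), dtype=torch.int32)
    for col in range(qweight.shape[1]):
        for i, src in enumerate(AWQ_ORDER):
            # Source column AWQ_ORDER[i] lands in nibble i of the packed int32.
            qweight[:, col] |= intweight[:, col * 8 + src] << (i * 4)
    return qweight


def unpack_int4(qweight: torch.Tensor) -> torch.Tensor:
    """Invert pack_int4: recover the original (rows, cols) tensor of 4-bit values."""
    rows, packed_cols = qweight.shape
    out = torch.zeros((rows, packed_cols * 8), dtype=torch.int32)
    for col in range(packed_cols):
        for i, src in enumerate(AWQ_ORDER):
            # Mask off the sign extension from the arithmetic right shift.
            out[:, col * 8 + src] = (qweight[:, col] >> (i * 4)) & 0xF
    return out


if __name__ == "__main__":
    w = torch.randint(0, 16, (2, 16), dtype=torch.int32)
    assert torch.equal(unpack_int4(pack_int4(w)), w)  # round-trip check
```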