@@ -48,18 +48,33 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
4848 requires_grad = False )
4949 # INPUT SCALE
5050 if self .is_static_input_scheme :
51- layer .input_scale = Parameter (layer .input_scale .max (),
52- requires_grad = False )
53- if not self .input_symmetric :
54- layer .input_zero_point = Parameter (layer .input_zero_point ,
55- requires_grad = False )
51+ if self .input_symmetric :
52+ layer .input_scale = Parameter (layer .input_scale .max (),
53+ requires_grad = False )
5654 else :
57- layer .input_zero_point = None
55+ raise NotImplementedError (
56+ "static input asymmetric quantization not supported yet" )
57+ # reconstruct the ranges
58+ int8_traits = torch .iinfo (torch .int8 )
59+ range_max = (layer .input_scale *
60+ (int8_traits .max - layer .input_zero_point )).max ()
61+ range_min = (layer .input_scale *
62+ (int8_traits .min - layer .input_zero_point )).min ()
63+
64+ scale = (range_max - range_min ) / (int8_traits .max -
65+ int8_traits .min )
66+ layer .input_scale = Parameter (scale , requires_grad = False )
67+
68+ azp = int8_traits .min - range_min / scale
69+ layer .input_zero_point = Parameter (azp , requires_grad = False )
70+
5871 else :
5972 layer .input_scale = None
6073 layer .input_zero_point = None
6174
6275 if not self .input_symmetric :
76+ # azp_adj is the AZP adjustment term, used to account for weights.
77+ # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
6378 layer .azp_adj = layer .weight .sum (dim = 0 ,
6479 keepdim = True ,
6580 dtype = torch .int32 )
@@ -108,7 +123,7 @@ def create_weights(self, layer: torch.nn.Module,
108123 if not self .input_symmetric :
109124 raise NotImplementedError (
110125 "static input asymmetric quantization not supported yet" )
111- input_zero_point = Parameter (torch .zeros (1 , dtype = torch .int8 ))
126+ input_zero_point = Parameter (torch .zeros (1 , dtype = torch .int32 ))
112127 layer .register_parameter ("input_zero_point" , input_zero_point )
113128
114129 def apply_weights (self , layer : torch .nn .Module , x : torch .Tensor ,
0 commit comments