2424from paddle .framework import set_default_dtype
2525from paddle .pir_utils import test_with_pir_api
2626
27- np .random .seed (123 )
28- paddle .seed (42 )
29-
3027
3128@unittest .skipIf (
3229 not core .is_compiled_with_cuda ()
@@ -43,11 +40,13 @@ def config(self):
4340 self .batch = 1
4441 self .token = 32
4542 self .in_features = 64
46- self .out_features = 256
43+ self .out_features = 128
4744 self .threshold = 6.0
4845 self .static = False
4946
5047 def setUp (self ):
48+ np .random .seed (123 )
49+ paddle .seed (42 )
5150 self .config ()
5251 x = np .random .random ((self .batch , self .token , self .in_features ))
5352 self .x = paddle .to_tensor (x , dtype = self .dtype )
@@ -64,49 +63,89 @@ def setUp(self):
6463 self .in_features , self .out_features , bias_attr = bias_attr
6564 )
6665
67- self .bias = self .linear .bias
6866 self .weight = self .linear .weight
6967 self .weight_scale = None
7068 self .weight , self .weight_scale = Q .weight_quantize (
7169 self .weight , algo = "llm.int8"
7270 )
7371
72+ def dynamic_quant (self , x ):
73+ row_ranges = paddle .max (x , axis = [- 1 ]).astype ('float32' )
74+ row_ranges = row_ranges .unsqueeze (- 1 )
75+ quant_x = paddle .round (
76+ paddle .clip (
77+ x .astype ('float32' ) * 127.0 * (1 / row_ranges ),
78+ min = - 127.0 ,
79+ max = 127.0 ,
80+ )
81+ ).astype ('int8' )
82+ return quant_x , row_ranges
83+
7484 def get_linear_out (self ):
75- out = self .linear (self .x )
85+ outlier_cols = (
86+ paddle .nonzero (paddle .max (self .x , axis = [0 , 1 ]) > self .threshold )
87+ .reshape ([- 1 ])
88+ .numpy ()
89+ .tolist ()
90+ )
91+
92+ x_int8 = self .x
93+ if len (outlier_cols ) > 0 :
94+ x_fp = self .x [:, :, outlier_cols ]
95+ w_fp = self .linear .weight [outlier_cols ]
96+ res_fp = paddle .matmul (x_fp , w_fp )
97+
98+ x_int8 [:, :, outlier_cols ] = 0
99+ x_int8 , row_ranges = self .dynamic_quant (x_int8 )
100+
101+ res_int8 = paddle .matmul (x_int8 , self .weight .transpose ((1 , 0 )))
102+ dequant_scale = row_ranges * self .weight_scale / 127.0
103+ res_dequant = (res_int8 .astype ('float32' ) * dequant_scale ).astype (
104+ self .dtype
105+ )
106+
107+ if len (outlier_cols ) > 0 :
108+ out = res_dequant + res_fp
109+ else :
110+ out = res_dequant
111+
112+ if self .bias :
113+ out += self .bias
114+
76115 return out .numpy ()
77116
78117 def get_llm_int8_linear_out (self ):
79118 out = Q .llm_int8_linear (
80119 self .x ,
81120 self .weight ,
82- bias = self .bias ,
121+ bias = self .linear . bias ,
83122 weight_scale = self .weight_scale ,
84123 threshold = self .threshold ,
85124 )
86125 return out .numpy ()
87126
88127 @test_with_pir_api
89- def get_llm_int8_linear_out_static (self ):
128+ def llm_int8_linear_out_static (self , out_expect ):
90129 paddle .enable_static ()
91- main = base .static .Program ()
92- start = base .static .Program ()
93- with base .static .program_guard (main , start ):
94- x = paddle .static .data ("x" , self .x .shape , dtype = self .x . dtype )
130+ main = paddle .static .Program ()
131+ start = paddle .static .Program ()
132+ with paddle .static .program_guard (main , start ):
133+ x = paddle .static .data ("x" , self .x .shape , dtype = self .dtype )
95134
96135 weight = paddle .static .data (
97- "weight" , self .weight .shape , dtype = self . weight . dtype
136+ "weight" , self .weight .shape , dtype = 'int8'
98137 )
99138 bias = paddle .static .data (
100- "bias" , self .bias .shape , dtype = self . bias .dtype
139+ "bias" , self .linear . bias .shape , dtype = self .dtype
101140 )
102141 x_np = self .x .numpy ()
103142 weight_np = self .weight .numpy ()
104- bias_np = self .bias .numpy ()
143+ bias_np = self .linear . bias .numpy ()
105144 if self .weight_scale is not None :
106145 weight_scale = paddle .static .data (
107146 "weight_scale" ,
108147 self .weight_scale .shape ,
109- dtype = self . weight_scale . dtype ,
148+ dtype = 'float32' ,
110149 )
111150 weight_scale_np = self .weight_scale .numpy ()
112151 else :
@@ -128,20 +167,30 @@ def get_llm_int8_linear_out_static(self):
128167 }
129168 exe = base .Executor (paddle .CUDAPlace (0 ))
130169 exe .run (start )
131- (out ,) = exe .run (main , feed = feed_dict , fetch_list = [out ])
170+ (out_real ,) = exe .run (main , feed = feed_dict , fetch_list = [out ])
171+
132172 paddle .disable_static ()
133- return out
173+
174+ if self .dtype == "bfloat16" :
175+ out_real = convert_uint16_to_float (out_real )
176+ out_expect = convert_uint16_to_float (out_expect )
177+
178+ np .testing .assert_allclose (
179+ out_real , out_expect , rtol = self .rtol , atol = self .atol
180+ )
134181
135182 def test_llm_int8_linear (self ):
136183 out_expect = self .get_linear_out ()
137184 if self .static :
138- out_real = self .get_llm_int8_linear_out_static ()
185+ self .llm_int8_linear_out_static (out_expect )
186+ return
139187 else :
140188 out_real = self .get_llm_int8_linear_out ()
141189
142190 if self .dtype == "bfloat16" :
143191 out_real = convert_uint16_to_float (out_real )
144192 out_expect = convert_uint16_to_float (out_expect )
193+
145194 np .testing .assert_allclose (
146195 out_real , out_expect , rtol = self .rtol , atol = self .atol
147196 )
@@ -174,19 +223,6 @@ def config(self):
174223 self .weight_dtype = "int8"
175224
176225
177- @unittest .skipIf (
178- not core .is_compiled_with_cuda ()
179- or get_cuda_version () < 11020
180- or paddle .device .cuda .get_device_capability ()[0 ] < 8 ,
181- "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8" ,
182- )
183- class LLMInt8LinearTestCase3 (LLMInt8LinearTestCase ):
184- def config (self ):
185- super ().config ()
186- self .dtype = 'bfloat16'
187- self .weight_dtype = "int8"
188-
189-
190226@unittest .skipIf (
191227 not core .is_compiled_with_cuda ()
192228 or get_cuda_version () < 11020
@@ -215,20 +251,6 @@ def config(self):
215251 self .weight_dtype = "int4"
216252
217253
218- @unittest .skipIf (
219- not core .is_compiled_with_cuda ()
220- or get_cuda_version () < 11020
221- or paddle .device .cuda .get_device_capability ()[0 ] < 8
222- or not core .is_bfloat16_supported (core .CUDAPlace (0 )),
223- "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16" ,
224- )
225- class LLMInt8LinearTestCase6 (LLMInt8LinearTestCase ):
226- def config (self ):
227- super ().config ()
228- self .dtype = 'bfloat16'
229- self .weight_dtype = "int4"
230-
231-
232254@unittest .skipIf (
233255 not core .is_compiled_with_cuda ()
234256 or get_cuda_version () < 11020
@@ -260,21 +282,6 @@ def config(self):
260282 self .token = 1
261283
262284
263- @unittest .skipIf (
264- not core .is_compiled_with_cuda ()
265- or get_cuda_version () < 11020
266- or paddle .device .cuda .get_device_capability ()[0 ] < 8 ,
267- "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8" ,
268- )
269- class LLMInt8LinearTestCase9 (LLMInt8LinearTestCase ):
270- def config (self ):
271- super ().config ()
272- self .dtype = 'bfloat16'
273- self .weight_dtype = "int8"
274- self .batch = 1
275- self .token = 1
276-
277-
278285@unittest .skipIf (
279286 not core .is_compiled_with_cuda ()
280287 or get_cuda_version () < 11020
0 commit comments