 import triton
 import triton.language as tl
 
+from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
@@ -21,20 +22,34 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = True,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
-    output = w8a8_block_fp8_matmul(q_input,
-                                   weight,
-                                   x_scale,
-                                   weight_scale,
-                                   block_size,
-                                   output_dtype=input.dtype)
-
+    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
+                                  and weight.shape[1] % 128 == 0)
+    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
+        q_input, x_scale = per_token_group_quant_fp8(input_2d,
+                                                     block_size[1],
+                                                     column_major_scales=True)
+        output = ops.cutlass_scaled_mm(q_input,
+                                       weight.T,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale.T)
+    else:
+        q_input, x_scale = per_token_group_quant_fp8(input_2d,
+                                                     block_size[1],
+                                                     column_major_scales=False)
+        output = w8a8_block_fp8_matmul(q_input,
+                                       weight,
+                                       x_scale,
+                                       weight_scale,
+                                       block_size,
+                                       output_dtype=input.dtype)
     if bias is not None:
         output = output + bias
     return output.to(dtype=input.dtype).view(*output_shape)
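
Illustrative call-site sketch, not part of the diff: the new branch sends the GEMM to ops.cutlass_scaled_mm only when both weight dimensions are multiples of 128 and the caller reports CUTLASS block-fp8 support; every other shape keeps the Triton w8a8_block_fp8_matmul path. The snippet assumes the module path and the leading positional parameters (input, weight, block_size), which are not shown in this hunk, and a GPU/build where the CUTLASS block-fp8 kernels are available.

# Hypothetical usage sketch; import path and the leading positional
# parameters (input, weight, block_size) are assumptions, not shown above.
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_w8a8_block_fp8_linear)

block_size = [128, 128]
x = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")
# Both weight dims are multiples of 128, so the CUTLASS branch is eligible.
w = torch.randn(1024, 512, device="cuda").to(torch.float8_e4m3fn)
w_scale = torch.rand(1024 // 128, 512 // 128,
                     dtype=torch.float32, device="cuda")

out = apply_w8a8_block_fp8_linear(x, w, block_size, w_scale)
assert out.shape == (16, 1024)

# Forcing the Triton fallback instead of CUTLASS:
out_triton = apply_w8a8_block_fp8_linear(x, w, block_size, w_scale,
                                         cutlass_block_fp8_supported=False)
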
@@ -98,10 +113,7 @@ def _per_token_group_quant_fp8(
     y_ptr,
     y_q_ptr,
     y_s_ptr,
-    # Stride of input
-    y_stride,
-    # Columns of input
-    N,
+    group_size,
     # Avoid to divide zero
     eps,
     # Information for float8
@@ -116,12 +128,60 @@ def _per_token_group_quant_fp8(
116128 """
117129 # Map the program id to the row of X and Y it should compute.
118130 g_id = tl .program_id (0 )
119- y_ptr += g_id * y_stride
120- y_q_ptr += g_id * y_stride
131+ y_ptr += g_id * group_size
132+ y_q_ptr += g_id * group_size
121133 y_s_ptr += g_id
122134
123135 cols = tl .arange (0 , BLOCK ) # N <= BLOCK
124- mask = cols < N
136+ mask = cols < group_size
137+
138+ y = tl .load (y_ptr + cols , mask = mask , other = 0.0 ).to (tl .float32 )
139+ # Quant
140+ _absmax = tl .maximum (tl .max (tl .abs (y )), eps )
141+ y_s = _absmax / fp8_max
142+ y_q = tl .clamp (y / y_s , fp8_min , fp8_max ).to (y_q_ptr .dtype .element_ty )
143+
144+ tl .store (y_q_ptr + cols , y_q , mask = mask )
145+ tl .store (y_s_ptr , y_s )
146+
147+
148+ @triton .jit
149+ def _per_token_group_quant_fp8_colmajor (
150+ # Pointers to inputs and output
151+ y_ptr ,
152+ y_q_ptr ,
153+ y_s_ptr ,
154+ group_size ,
155+ # Num columns of y
156+ y_num_columns ,
157+ # Stride from one column to the next of y_s
158+ y_s_col_stride ,
159+ # Avoid to divide zero
160+ eps ,
161+ # Information for float8
162+ fp8_min ,
163+ fp8_max ,
164+ # Meta-parameters
165+ BLOCK : tl .constexpr ,
166+ ):
167+ """A Triton-accelerated function to perform per-token-group
168+ quantization on a tensor.
169+ This function converts the tensor values into float8 values.
170+ """
171+ # Map the program id to the row of X and Y it should compute.
172+ g_id = tl .program_id (0 )
173+ y_ptr += g_id * group_size
174+ y_q_ptr += g_id * group_size
175+
176+ # Convert g_id the flattened block coordinate to 2D so we can index
177+ # into the output y_scales matrix
178+ blocks_per_row = y_num_columns // group_size
179+ scale_col = g_id % blocks_per_row
180+ scale_row = g_id // blocks_per_row
181+ y_s_ptr += scale_col * y_s_col_stride + scale_row
182+
183+ cols = tl .arange (0 , BLOCK ) # group_size <= BLOCK
184+ mask = cols < group_size
125185
126186 y = tl .load (y_ptr + cols , mask = mask , other = 0.0 ).to (tl .float32 )
127187 # Quant
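
Illustrative check, not part of the diff: the colmajor kernel maps the flat program id to a (scale_row, scale_col) coordinate and writes the group's single scale at scale_col * y_s_col_stride + scale_row. The CPU-only sketch below mirrors that arithmetic against the column-major scale tensor that per_token_group_quant_fp8 allocates further down, confirming the offset lands on y_s[scale_row, scale_col].

# CPU-only sketch of the colmajor index math, with small toy sizes.
import torch

num_tokens, y_num_columns, group_size = 4, 6, 3
blocks_per_row = y_num_columns // group_size  # groups per input row

# Scale tensor as allocated below with column_major_scales=True:
# logical shape (num_tokens, blocks_per_row), strides (1, num_tokens).
y_s = torch.empty(blocks_per_row, num_tokens,
                  dtype=torch.float32).permute(-1, -2)
y_s_col_stride = y_s.stride(1)  # == num_tokens, what the kernel receives

for g_id in range(num_tokens * blocks_per_row):
    scale_col = g_id % blocks_per_row   # which group within the token's row
    scale_row = g_id // blocks_per_row  # which token (input row)
    offset = scale_col * y_s_col_stride + scale_row
    # The kernel's flat offset addresses exactly y_s[scale_row, scale_col].
    assert offset == scale_row * y_s.stride(0) + scale_col * y_s.stride(1)
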
@@ -138,12 +198,13 @@ def per_token_group_quant_fp8(
     group_size: int,
     eps: float = 1e-10,
     dtype: Optional[torch.dtype] = None,
+    column_major_scales: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
-        x: The input tenosr with ndim >= 2.
+        x: The input tensor with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
@@ -167,29 +228,46 @@ def per_token_group_quant_fp8(
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size, ),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    if column_major_scales:
+        shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
+        x_s = torch.empty(shape, device=x.device,
+                          dtype=torch.float32).permute(-1, -2)
+    else:
+        shape = x.shape[:-1] + (x.shape[-1] // group_size, )
+        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M, )](
-        x,
-        x_q,
-        x_s,
-        group_size,
-        N,
-        eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
-        BLOCK=BLOCK,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    if column_major_scales:
+        _per_token_group_quant_fp8_colmajor[(M, )](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            x.shape[1],
+            x_s.stride(1),
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+    else:
+        _per_token_group_quant_fp8[(M, )](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
 
     return x_q, x_s
 
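
Illustrative sketch, not part of the diff: with column_major_scales=True the scale buffer is allocated transposed and then permuted back, so the returned x_s keeps the usual (num_tokens, num_groups) shape while its storage is column-major, the layout the CUTLASS branch of apply_w8a8_block_fp8_linear above consumes. A CPU-only look at just that allocation:

# CPU-only sketch of the two scale layouts (allocation only, no kernel launch).
import torch

x = torch.randn(16, 512)  # (num_tokens, hidden_size) stand-in
group_size = 128

# column_major_scales=False: row-major scales, shape (16, 4).
row_major = torch.empty(x.shape[:-1] + (x.shape[-1] // group_size, ))
assert row_major.stride() == (4, 1)

# column_major_scales=True: same logical shape, column-major storage.
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
col_major = torch.empty(shape, dtype=torch.float32).permute(-1, -2)
assert col_major.shape == (16, 4)
assert col_major.stride() == (1, 16)  # columns are contiguous in memory
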