
Commit 93cd546

rasmith authored and mgoin committed
[Bugfix][VLM] Make apply_fp8_linear work with >2D input (vllm-project#9812)
Signed-off-by: Randall Smith <[email protected]>
1 parent 666f0b8 commit 93cd546

File tree

1 file changed with 20 additions and 13 deletions


vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 20 additions & 13 deletions
@@ -96,21 +96,26 @@ def apply_fp8_linear(
     # If dynamic, layer.input_scale is None and x_scale computed from x.
     # If static, layer.input_scale is scalar and x_scale is input_scale.
 
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[1]]
+
     # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             scale_ub=input_scale_ub,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
 
         # Fused GEMM_DQ
-        return ops.cutlass_scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=input.dtype,
-                                     scale_a=x_scale,
-                                     scale_b=weight_scale,
-                                     bias=bias)
+        output = ops.cutlass_scaled_mm(qinput,
+                                       weight,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale,
+                                       bias=bias)
+        return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
@@ -119,7 +124,7 @@ def apply_fp8_linear(
         # for matrices with batch dimension > 16.
         # This could change in the future.
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             num_token_padding=17,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
@@ -138,8 +143,10 @@ def apply_fp8_linear(
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
-                return torch.narrow(output[0], 0, 0, input.shape[0])
-            return torch.narrow(output, 0, 0, input.shape[0])
+                output = output[0]
+
+            return torch.narrow(output, 0, 0,
+                                input_2d.shape[0]).view(*output_shape)
 
         else:
             # Fallback for channelwise case, where we use unfused DQ
@@ -176,15 +183,15 @@ def apply_fp8_linear(
             if type(output) is tuple and len(output) == 2:
                 output = output[0]
             # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
+            output = torch.narrow(output, 0, 0, input_2d.shape[0])
+            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
 
             # DQ
             # C = sw * sx * (X * W) + bias
             output = output * x_scale * weight_scale.t()
             if bias is not None:
                 output = output + bias
-            return output.to(dtype=input.dtype)
+            return output.to(dtype=input.dtype).view(*output_shape)
 
 
 def apply_int8_linear(
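
For context, below is a minimal sketch of the shape handling this commit adds: flatten any >2D activation to a 2D matrix before the GEMM, then restore the leading dimensions afterwards. It uses plain PyTorch only; the function name toy_fp8_style_linear, the tensor shapes, and the plain matmul standing in for the quantized, fused GEMM_DQ are illustrative assumptions, not vLLM code.

import torch


def toy_fp8_style_linear(input: torch.Tensor,
                         weight: torch.Tensor) -> torch.Tensor:
    # Same pattern as the patched apply_fp8_linear:
    # collapse all leading dimensions into a single row dimension.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[1]]
    # Plain matmul as a stand-in for the fused, quantized GEMM_DQ.
    output = input_2d @ weight
    # Restore the original leading dimensions (batch, seq, ...).
    return output.view(*output_shape)


# A VLM-style 3D activation of shape (batch, seq_len, hidden_size)
# now round-trips through the 2D GEMM and comes back with its
# original leading dimensions.
x = torch.randn(2, 16, 64)
w = torch.randn(64, 128)
assert toy_fp8_style_linear(x, w).shape == (2, 16, 128)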

0 commit comments
