@@ -96,21 +96,26 @@ def apply_fp8_linear(
     # If dynamic, layer.input_scale is None and x_scale computed from x.
     # If static, layer.input_scale is scalar and x_scale is input_scale.

+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[1]]
+
     # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             scale_ub=input_scale_ub,
             use_per_token_if_dynamic=use_per_token_if_dynamic)

         # Fused GEMM_DQ
-        return ops.cutlass_scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=input.dtype,
-                                     scale_a=x_scale,
-                                     scale_b=weight_scale,
-                                     bias=bias)
+        output = ops.cutlass_scaled_mm(qinput,
+                                       weight,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale,
+                                       bias=bias)
+        return output.view(*output_shape)

     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
@@ -119,7 +124,7 @@ def apply_fp8_linear(
         # for matrices with batch dimension > 16.
         # This could change in the future.
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             num_token_padding=17,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
@@ -138,8 +143,10 @@ def apply_fp8_linear(
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
-                return torch.narrow(output[0], 0, 0, input.shape[0])
-            return torch.narrow(output, 0, 0, input.shape[0])
+                output = output[0]
+
+            return torch.narrow(output, 0, 0,
+                                input_2d.shape[0]).view(*output_shape)

         else:
             # Fallback for channelwise case, where we use unfused DQ
@@ -176,15 +183,15 @@ def apply_fp8_linear(
             if type(output) is tuple and len(output) == 2:
                 output = output[0]
             # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
+            output = torch.narrow(output, 0, 0, input_2d.shape[0])
+            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

             # DQ
             # C = sw * sx * (X * W) + bias
             output = output * x_scale * weight_scale.t()
             if bias is not None:
                 output = output + bias
-            return output.to(dtype=input.dtype)
+            return output.to(dtype=input.dtype).view(*output_shape)


 def apply_int8_linear(
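
To make the change above concrete: the patch flattens any leading batch/sequence dimensions into a single token dimension before the fp8 GEMM and restores the original shape afterwards, so 3D (or higher-rank) inputs work with kernels that expect 2D activations. Below is a minimal, self-contained sketch of that pattern in plain PyTorch. The function name linear_with_2d_view, the ordinary matmul standing in for ops.cutlass_scaled_mm / torch._scaled_mm, and the torch.nn.functional.pad call standing in for num_token_padding are illustrative assumptions, not code from the patch.

import torch


def linear_with_2d_view(input: torch.Tensor,
                        weight: torch.Tensor,
                        num_token_padding: int = 0) -> torch.Tensor:
    # Collapse [batch..., hidden] -> [num_tokens, hidden] so the GEMM only
    # ever sees a 2D activation, regardless of the input rank.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[1]]
    num_tokens = input_2d.shape[0]

    # Stand-in for num_token_padding: pad the token dimension to a minimum
    # size (the patch uses 17 because torch._scaled_mm is more performant
    # for batch dimension > 16).
    if num_token_padding and num_tokens < num_token_padding:
        input_2d = torch.nn.functional.pad(
            input_2d, (0, 0, 0, num_token_padding - num_tokens))

    # Stand-in for the quantized GEMM (ops.cutlass_scaled_mm / _scaled_mm).
    output = input_2d @ weight

    # Drop any padded rows, then restore the original leading dimensions.
    output = torch.narrow(output, 0, 0, num_tokens)
    return output.view(*output_shape)


x = torch.randn(2, 5, 64)    # [batch, seq, hidden]
w = torch.randn(64, 128)     # [hidden, out_features]
y = linear_with_2d_view(x, w, num_token_padding=17)
assert y.shape == (2, 5, 128)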
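
The type(output) is tuple check kept by the patch exists because, as the diff's own comment notes, torch._scaled_mm returned a two-element tuple on torch < 2.5 but a single tensor on torch >= 2.5. A hypothetical helper spelling out that normalization (not part of the patch) could look like:

def unwrap_scaled_mm_output(output):
    # torch._scaled_mm returns a two-element tuple on torch < 2.5 and a
    # single tensor on torch >= 2.5; normalize to a tensor in both cases.
    if type(output) is tuple and len(output) == 2:
        return output[0]
    return output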