@@ -268,14 +268,23 @@ def __init__(
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")

-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per-tensor "
+                "or channelwise, dynamic per-token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")

         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layers, static input scales are not "
+                "supported with channelwise weights. Use dynamic per-token "
+                "activation quantization instead.")
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
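The new validation accepts exactly two pairings: per-tensor weights with per-tensor activations, or channelwise weights with dynamic per-token activations. A minimal sketch of that rule, using plain strings in place of the compressed-tensors `QuantizationStrategy` enum; the helper name `validate_moe_scheme` is hypothetical, for illustration only:

```python
def validate_moe_scheme(weight_strategy: str, act_strategy: str,
                        act_dynamic: bool) -> None:
    # Hypothetical helper mirroring the checks added in __init__ above.
    per_tensor = weight_strategy == "tensor" and act_strategy == "tensor"
    per_channel = weight_strategy == "channel" and act_strategy == "token"
    if not (per_tensor or per_channel):
        raise ValueError("FP8 Fused MoE needs per-tensor or channelwise, "
                         "dynamic per-token quantization")
    # Static (pre-calibrated) activation scales only make sense per-tensor;
    # per-token scales must be computed on the fly for each input.
    if per_channel and not act_dynamic:
        raise ValueError("channelwise weights require dynamic per-token "
                         "activation quantization")

validate_moe_scheme("tensor", "tensor", act_dynamic=False)  # accepted
validate_moe_scheme("channel", "token", act_dynamic=True)   # accepted
```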
@@ -303,24 +312,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

         # INPUT_SCALES
         if self.static_input_scales:
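The two branches above register scales of different shapes: per-tensor keeps one scale per expert (two for the fused w1/w3 until they are merged), while channelwise keeps one scale per output row with a trailing unit dimension so it broadcasts over the weight. A standalone sketch of those shapes, assuming example sizes and using only `torch`:

```python
import torch

num_experts, hidden_size, intermediate = 8, 1024, 4096

# Per-tensor: one scale per (expert, shard); w1 and w3 each get one.
w13_scale_tensor = torch.ones(num_experts, 2, dtype=torch.float32)
w2_scale_tensor = torch.ones(num_experts, dtype=torch.float32)

# Per-channel: one scale per output row, with a trailing unit dim so the
# scale broadcasts against weights of shape (num_experts, rows, cols).
w13_scale_channel = torch.ones(num_experts, 2 * intermediate, 1,
                               dtype=torch.float32)
w2_scale_channel = torch.ones(num_experts, hidden_size, 1,
                              dtype=torch.float32)

w13_weight = torch.empty(num_experts, 2 * intermediate, hidden_size)
assert (w13_weight * w13_scale_channel).shape == w13_weight.shape  # broadcasts
```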
@@ -362,6 +387,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -377,24 +403,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)

-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For the per-TENSOR case, the Fp8 moe kernel needs a single weight
+        # scale for w13 per expert. Take the max, then dequant and requant
+        # each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
     def apply(
         self,
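Guarding the requant loop on the TENSOR strategy leaves channelwise scales untouched, since each row already carries its own scale. For the per-tensor path, here is a standalone sketch of the max/dequant/requant round-trip performed by `per_tensor_dequantize` and `ops.scaled_fp8_quant`, with float32 stand-ins for the fp8 tensors and example shapes:

```python
import torch

shard_size, hidden = 4, 8
w13 = torch.randn(2 * shard_size, hidden)   # stand-in for quantized values
shard_scales = torch.tensor([0.02, 0.05])   # one scale per shard (w1, w3)

max_scale = shard_scales.max()
for shard_id in range(2):
    rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
    dq = w13[rows] * shard_scales[shard_id]  # dequantize with the old scale
    w13[rows] = dq / max_scale               # requantize with the shared scale
# Every row is now expressed in units of max_scale, so the fused kernel can
# apply a single per-expert weight scale for w13.
```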