 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -38,7 +38,7 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_block_fp8_supported: bool = True,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
@@ -85,12 +85,14 @@ def apply_w8a8_block_fp8_linear(
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
 def apply_fp8_linear_generic(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    input_group_shape: Tuple[int, int],
-    weight_group_shape: Tuple[int, int],
-    input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_group_shape: Tuple[int, int],
+    weight_group_shape: Tuple[int, int],
+    input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     # View input as 2D matrix for fp8 methods
     input = input.view(-1, input.shape[-1])
@@ -105,14 +107,18 @@ def is_dim_blocked(dim, shape, group_shape):
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and \
         input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight,
-                                           list(weight_group_shape),
-                                           weight_scale)
+        return apply_w8a8_block_fp8_linear(
+            input,
+            weight,
+            list(weight_group_shape),
+            weight_scale,
+            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
     else:
         # Despite having linear in the name, it doesn't conform to
         # `torch.nn.functional.linear`, which is defined as `input @ weight.T`,
         # so we explicitly transpose the weight matrix here
         return apply_fp8_linear(input, weight.T, weight_scale.T,
+                                cutlass_fp8_supported=cutlass_fp8_supported,
                                 use_per_token_if_dynamic=\
                                 (input_group_shape == (1, input.shape[1])))
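The two new defaults, CUTLASS_FP8_SUPPORTED and CUTLASS_BLOCK_FP8_SUPPORTED, are imported from w8a8_utils and are evaluated once at import time rather than per call. Below is a minimal sketch of that pattern; the probe helpers and the capability thresholds are illustrative assumptions, not the actual w8a8_utils implementation:

import torch

def _cutlass_fp8_probe() -> bool:
    # Hypothetical probe: plain CUTLASS FP8 GEMMs generally need an
    # NVIDIA GPU with compute capability >= 8.9 (Ada or newer).
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 9)

def _cutlass_block_fp8_probe() -> bool:
    # Hypothetical probe: block-scaled (e.g. 128x128) CUTLASS kernels
    # tend to be gated more tightly, e.g. Hopper (SM90) and newer.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (9, 0)

# Computed once at import time; these become the new keyword defaults.
CUTLASS_FP8_SUPPORTED: bool = _cutlass_fp8_probe()
CUTLASS_BLOCK_FP8_SUPPORTED: bool = _cutlass_block_fp8_probe()

Hoisting the probe to import time is what makes it usable as a keyword default: Python evaluates default values once at function definition, so the check costs nothing on the hot path, while callers can still pass an explicit value to override it.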
0 commit comments
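A short usage sketch of the updated apply_fp8_linear_generic, assuming it lives in vLLM's fp8_utils module alongside the functions shown above; the shapes, dtypes, and scale layout are illustrative, and the explicit False override exercises the non-CUTLASS fallback path:

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_linear_generic)

# Illustrative shapes: 4 tokens, K=256, N=512, 128x128 weight blocks
# with matching (1, 128) per-token-group activation quantization.
x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
w = torch.randn(512, 256, device="cuda").to(torch.float8_e4m3fn)
w_scale = torch.rand(512 // 128, 256 // 128,
                     dtype=torch.float32, device="cuda")

# The weight is fully blocked and input_group_shape equals
# (1, weight_group_shape[1]), so this dispatches to
# apply_w8a8_block_fp8_linear; forcing the flag to False selects the
# non-CUTLASS (Triton) kernel regardless of the detected hardware.
out = apply_fp8_linear_generic(x, w, w_scale,
                               input_group_shape=(1, 128),
                               weight_group_shape=(128, 128),
                               cutlass_block_fp8_supported=False)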