@@ -75,42 +75,42 @@ def fuse_qkv(W_qs, scales, zeros):
7575 """
7676 Args:
7777 W_qs (list[torch.Tensor]): len 3 list of tensors with shapes Nq x K, Nk x K, Nv x K where Nk == Nv
78- scales (list[torch.Tensor]): each is N x (K // group_size ), with same N requirements per W_qs
78+ scales (list[torch.Tensor]): each is N x (K // groupsize ), with same N requirements per W_qs
7979 zeros (list[torch.Tensor]): same as scales
8080
8181 Returns:
8282 qkv (torch.Tensor): (N_qkv x K) where N_qkv = Nq + Nk + Nv
83- scales (torch.Tensor): (N_qkv x (K // group_size ))
84- zeros (torch.Tensor): (N_qkv x (K // group_size ))
83+ scales (torch.Tensor): (N_qkv x (K // groupsize ))
84+ zeros (torch.Tensor): (N_qkv x (K // groupsize ))
8585 """
8686 qkv = torch .cat (W_qs , dim = 0 ) # Fuse along N
8787 fused_scales = torch .cat ([s for s in scales ], dim = 0 )
8888 fused_zeros = torch .cat ([z for z in zeros ], dim = 0 )
8989 return qkv , fused_scales , fused_zeros
9090
9191
92- def ref_proj (x , packed_w , scale , zero , group_size , kernel_type , transposed = False ):
92+ def ref_proj (x , packed_w , scale , zero , groupsize , kernel_type , transposed = False ):
9393 return triton_mixed_mm (
9494 x ,
9595 packed_w ,
9696 scale .T ,
9797 zero .T ,
9898 transposed = transposed ,
99- group_size = group_size ,
99+ groupsize = groupsize ,
100100 fp8_fast_accum = False ,
101101 kernel_type = kernel_type ,
102102 )
103103
104104
105105@pytest .mark .parametrize (
106- "q_shape, kv_shape, group_size , axis, dtype, transposed, kernel_type" ,
106+ "q_shape, kv_shape, groupsize , axis, dtype, transposed, kernel_type" ,
107107 TEST_CONFIGS ,
108108 ids = _arg_to_id ,
109109)
110110def test_mixed_mm (
111111 q_shape ,
112112 kv_shape ,
113- group_size ,
113+ groupsize ,
114114 axis ,
115115 dtype ,
116116 transposed ,
@@ -136,7 +136,7 @@ def test_mixed_mm(
136136
137137 qcfg = {
138138 ** BASE_QUANT_CONFIG ,
139- ** dict (group_size = group_size , axis = axis ),
139+ ** dict (groupsize = groupsize , axis = axis ),
140140 }
141141
142142 quant_config = BaseQuantizeConfig (
@@ -172,7 +172,7 @@ def test_mixed_mm(
172172 xs = [torch .randn (seqlen , n , dtype = dtype , device = device ) for n in Ns ]
173173 x_fused = torch .cat (xs , dim = 1 )
174174 q_ref , k_ref , v_ref = [
175- ref_proj (x , p , s , z , group_size , kernel_type , transposed = True )
175+ ref_proj (x , p , s , z , groupsize , kernel_type , transposed = True )
176176 for x , p , s , z in zip (xs , packed_ws , scales , zeros )
177177 ]
178178 tt_fused = triton_mixed_mm (
@@ -181,7 +181,7 @@ def test_mixed_mm(
181181 scales_fused .T ,
182182 zeros_fused .T ,
183183 transposed = True ,
184- group_size = group_size ,
184+ groupsize = groupsize ,
185185 fp8_fast_accum = False ,
186186 kernel_type = kernel_type ,
187187 )
@@ -191,7 +191,7 @@ def test_mixed_mm(
191191 x = torch .randn (seqlen , K , dtype = dtype , device = device )
192192
193193 q_ref , k_ref , v_ref = [
194- ref_proj (x , p , s , z , group_size , kernel_type )
194+ ref_proj (x , p , s , z , groupsize , kernel_type )
195195 for p , s , z in zip (packed_ws , scales , zeros )
196196 ]
197197
@@ -201,7 +201,7 @@ def test_mixed_mm(
201201 scales_fused .T ,
202202 zeros_fused .T ,
203203 transposed = False ,
204- group_size = group_size ,
204+ groupsize = groupsize ,
205205 fp8_fast_accum = False ,
206206 kernel_type = kernel_type ,
207207 )
0 commit comments