@@ -64,9 +64,7 @@ def _test_quantized_matmul(
       n_input_features: int,
       n_output_features: int,
       quantize_activation: bool,
-      batch_block_size=None,
-      out_block_size=None,
-      in_block_size=None,
+      tuned_value=None,
       atol=0.5,
       rtol=0.5,
   ):
@@ -88,14 +86,13 @@ def _test_quantized_matmul(
     w_scale = jnp.squeeze(w_scale)
     assert w_scale.shape == (n_output_features,)
 
+    x_q_dtype = w_q.dtype if quantize_activation else dtype
     output = quantized_matmul_kernel(
         x,
         w_q,
         w_scale,
-        quantize_activation=quantize_activation,
-        batch_block_size=batch_block_size,
-        out_block_size=out_block_size,
-        in_block_size=in_block_size,
+        x_q_dtype=x_q_dtype,
+        tuned_value=tuned_value,
     )
     expected = reference_quantized_matmul(
         x, w_q, w_scale, quantize_activation=quantize_activation)
@@ -130,9 +127,7 @@ def test_quantized_matmul_various_input_shapes(
         n_input_features,
         n_output_features,
         quantize_activation=quantize_activation,
-        batch_block_size=128,
-        out_block_size=128,
-        in_block_size=128,
+        tuned_value=None,
     )
 
   @parameterized.product(
@@ -159,9 +154,7 @@ def test_quantized_matmul_unaligned_input_shapes(
         n_input_features,
         n_output_features,
         quantize_activation=quantize_activation,
-        batch_block_size=128,
-        out_block_size=128,
-        in_block_size=128,
+        tuned_value=None,
     )
 
   @parameterized.parameters(
@@ -190,9 +183,7 @@ def test_quantized_matmul_use_tuned_block_sizes(
         n_input_features,
         n_output_features,
         quantize_activation=quantize_activation,
-        batch_block_size=None,
-        out_block_size=None,
-        in_block_size=None,
+        tuned_value=None,
     )
 
 
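For orientation, below is a minimal sketch of the reference path these tests compare against. It is not the repository's actual reference_quantized_matmul: the int8 weight layout of shape (n_input_features, n_output_features), the symmetric per-output-channel weight scales, and the per-tensor activation quantization scheme are all assumptions, chosen only to be consistent with the w_scale.shape == (n_output_features,) assertion in the hunk above.

import jax
import jax.numpy as jnp


def reference_quantized_matmul(x, w_q, w_scale, quantize_activation=False):
  # Assumed layout: w_q is int8 of shape (n_input_features, n_output_features)
  # and w_scale holds one symmetric scale per output feature, so w ~= w_q * w_scale.
  if quantize_activation:
    # Hypothetical symmetric per-tensor int8 activation quantization, mirroring
    # `x_q_dtype = w_q.dtype if quantize_activation else dtype` in the diff.
    x_scale = jnp.maximum(jnp.max(jnp.abs(x)), 1e-6) / 127.0
    x_q = jnp.round(x / x_scale).astype(jnp.int8)
    # Integer matmul accumulated in int32, then rescaled back to the input dtype.
    acc = jax.lax.dot(x_q, w_q, preferred_element_type=jnp.int32)
    return (acc.astype(x.dtype) * x_scale) * w_scale[None, :].astype(x.dtype)
  # Weight-only path: dequantize the weights and matmul in the activation dtype.
  return x @ (w_q.astype(x.dtype) * w_scale[None, :].astype(x.dtype))

Under these assumptions, the new call shape in the diff reads as quantized_matmul_kernel(x, w_q, w_scale, x_q_dtype=..., tuned_value=...), where tuned_value=None presumably lets the kernel fall back to its tuned or default block sizes instead of the explicit batch_block_size/out_block_size/in_block_size triple it replaces.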