Skip to content

Commit b826f44

Browse files
authored
[Kernel] Update tile tuning key for future mixed precision support (#715)
1 parent c1509f5 commit b826f44

File tree

6 files changed

+679
-587
lines changed

6 files changed

+679
-587
lines changed

tests/kernels/quantized_matmul_kernel_test.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def _test_quantized_matmul(
6464
n_input_features: int,
6565
n_output_features: int,
6666
quantize_activation: bool,
67-
batch_block_size=None,
68-
out_block_size=None,
69-
in_block_size=None,
67+
tuned_value=None,
7068
atol=0.5,
7169
rtol=0.5,
7270
):
@@ -88,14 +86,13 @@ def _test_quantized_matmul(
8886
w_scale = jnp.squeeze(w_scale)
8987
assert w_scale.shape == (n_output_features, )
9088

89+
x_q_dtype = w_q.dtype if quantize_activation else dtype
9190
output = quantized_matmul_kernel(
9291
x,
9392
w_q,
9493
w_scale,
95-
quantize_activation=quantize_activation,
96-
batch_block_size=batch_block_size,
97-
out_block_size=out_block_size,
98-
in_block_size=in_block_size,
94+
x_q_dtype=x_q_dtype,
95+
tuned_value=tuned_value,
9996
)
10097
expected = reference_quantized_matmul(
10198
x, w_q, w_scale, quantize_activation=quantize_activation)
@@ -130,9 +127,7 @@ def test_quantized_matmul_various_input_shapes(
130127
n_input_features,
131128
n_output_features,
132129
quantize_activation=quantize_activation,
133-
batch_block_size=128,
134-
out_block_size=128,
135-
in_block_size=128,
130+
tuned_value=None,
136131
)
137132

138133
@parameterized.product(
@@ -159,9 +154,7 @@ def test_quantized_matmul_unaligned_input_shapes(
159154
n_input_features,
160155
n_output_features,
161156
quantize_activation=quantize_activation,
162-
batch_block_size=128,
163-
out_block_size=128,
164-
in_block_size=128,
157+
tuned_value=None,
165158
)
166159

167160
@parameterized.parameters(
@@ -190,9 +183,7 @@ def test_quantized_matmul_use_tuned_block_sizes(
190183
n_input_features,
191184
n_output_features,
192185
quantize_activation=quantize_activation,
193-
batch_block_size=None,
194-
out_block_size=None,
195-
in_block_size=None,
186+
tuned_value=None,
196187
)
197188

198189

0 commit comments

Comments
 (0)