Commit 6a6c0fd

[tuner] use python binding to build td specs for attention (#2596)
This PR uses the Python bindings to build TD specs for the attention op. The bindings are exposed by IREE PR iree-org/iree#22311. All of the temporary-solution code and the string-based spec generation code has been removed; the SpecBuilder classes can now be used to build TD specs through the Python bindings. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent e138986 commit 6a6c0fd
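For context, a minimal sketch of the binding-backed flow this commit moves to, based only on the call sites changed below. The helper name build_attention_spec is hypothetical, and op_info, config_list, and tuner_ctx stand in for values produced elsewhere in the tuner.

# Minimal sketch, assuming sharktuner's spec_builder module and values produced
# elsewhere in the tuner (op_info from the dispatch parser, config_list of
# common.TuningConfiguration objects, tuner_ctx as the active TunerContext).
from sharktuner import spec_builder


def build_attention_spec(tuner_ctx, op_info, config_list):
    # The builder wraps the new IREE Python bindings (iree-org/iree#22311)
    # instead of emitting the transform-dialect spec as a string.
    builder = spec_builder.AttentionSpecBuilder(op_info)
    return builder.build_td_spec(tuner_ctx, config_list)  # returns an ir.Module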

File tree

7 files changed: +290 -339 lines changed


sharktuner/sharktuner/candidate_gen.py

Lines changed: 9 additions & 14 deletions

@@ -102,16 +102,15 @@ def supports_root_op(cls, root_op: ir.Operation) -> bool:
 
     def get_constraint_generator(self) -> constraint_generator.ConstraintGenerator:
         return constraint_generator.ContractionOpInterfaceConstraintGenerator(
-            self.get_root_op(), self.get_op_info()
+            self.get_op_info()
         )
 
     def get_td_spec(
         self,
         config_list: list[common.TuningConfiguration],
     ) -> ir.Module:
-        return spec_builder.build_contraction_td_spec(
-            self._tuner_ctx, self.get_op_info(), config_list
-        )
+        builder = spec_builder.ContractionSpecBuilder(self.get_op_info())
+        return builder.build_td_spec(self._tuner_ctx, config_list)
 
     @classmethod
     def get_dispatch_kind(cls) -> common.DispatchKind:
@@ -149,16 +148,15 @@ def supports_root_op(cls, root_op: ir.Operation) -> bool:
 
     def get_constraint_generator(self) -> constraint_generator.ConstraintGenerator:
         return constraint_generator.ConvolutionOpInterfaceConstraintGenerator(
-            self.get_root_op(), self.get_op_info()
+            self.get_op_info()
         )
 
     def get_td_spec(
         self,
         config_list: list[common.TuningConfiguration],
     ) -> ir.Module:
-        return spec_builder.build_convolution_td_spec(
-            self._tuner_ctx, self.get_op_info(), config_list
-        )
+        builder = spec_builder.ConvolutionSpecBuilder(self.get_op_info())
+        return builder.build_td_spec(self._tuner_ctx, config_list)
 
     @classmethod
     def get_dispatch_kind(cls) -> common.DispatchKind:
@@ -183,18 +181,15 @@ def supports_root_op(cls, root_op: ir.Operation) -> bool:
 
     def get_constraint_generator(self) -> constraint_generator.ConstraintGenerator:
         return constraint_generator.AttentionOpInterfaceConstraintGenerator(
-            self.get_root_op()
+            self.get_op_info()
        )
 
     def get_td_spec(
         self,
         config_list: list[common.TuningConfiguration],
     ) -> ir.Module:
-        attention_op = self.get_root_op()
-        func_name = spec_builder.get_matcher_named_sequence_name(attention_op)
-        return spec_builder.build_td_spec(
-            attention_op.context, attention_op, config_list, func_name
-        )
+        builder = spec_builder.AttentionSpecBuilder(self.get_op_info())
+        return builder.build_td_spec(self._tuner_ctx, config_list)
 
     @classmethod
     def get_dispatch_kind(cls) -> common.DispatchKind:
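After this change, all three dispatch kinds follow the same pattern: construct a builder from the parsed op info and call build_td_spec. The protocol below is a hypothetical illustration of that shared surface and is not part of this diff; the concrete classes live in spec_builder.py.

from typing import Any, Protocol


class TdSpecBuilder(Protocol):
    # Hypothetical protocol, shown only to record the surface that
    # candidate_gen.py relies on. The concrete classes are
    # ContractionSpecBuilder, ConvolutionSpecBuilder, and AttentionSpecBuilder,
    # each constructed from the parsed op_info.
    def build_td_spec(self, tuner_ctx: Any, config_list: list[Any]) -> Any:
        """Return the tuning spec as an ir.Module built via IREE's Python bindings."""
        ...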

sharktuner/sharktuner/common.py

Lines changed: 0 additions & 10 deletions

@@ -202,16 +202,6 @@ class MatmulShapeType:
     acc_type: ir.IntegerType | ir.FloatType
 
 
-@dataclass
-class AttentionOpInfo:
-    domain_rank: int
-    batch_dims: list[int]
-    m_dims: list[int]
-    n_dims: list[int]
-    k1_dims: list[int]
-    k2_dims: list[int]
-
-
 @dataclass
 class LLVMGPUVectorDistributeContractionKnobs(KnobAssignment):
     # Z3 numeric selections.
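The dataclass removed here is superseded by dispatch_parser.AttentionOpInfo, which the constraint-generator changes below read from. The sketch below is inferred only from those field accesses; the real definition lives in dispatch_parser.py, which is not part of this diff, and the class name used here is a placeholder.

from dataclasses import dataclass
from typing import Any


@dataclass
class AttentionOpInfoSketch:
    # Dimension structure (previously common.AttentionOpInfo).
    domain_rank: int
    batch_dims: list[int]
    m_dims: list[int]
    n_dims: list[int]
    k1_dims: list[int]
    k2_dims: list[int]
    # Shape/type metadata and layout flags (previously derived in
    # AttentionOpInterfaceConstraintGenerator.__init__).
    qk_matmul: Any  # common.MatmulShapeType for Q @ K^T
    pv_matmul: Any  # common.MatmulShapeType for P @ V
    transposed_q: bool
    transposed_k: bool
    transposed_v: bool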

sharktuner/sharktuner/constraint_generator.py

Lines changed: 34 additions & 144 deletions
@@ -284,12 +284,7 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
 def generate_attention_solutions(
     tuner_ctx: common.TunerContext,
     gpu_target_info: iree_gpu.TargetInfo,
-    opinfo: common.AttentionOpInfo,
-    qk_matmul: common.MatmulShapeType,
-    pv_matmul: common.MatmulShapeType,
-    transposed_q: bool,
-    transposed_k: bool,
-    transposed_v: bool,
+    op_info: dispatch_parser.AttentionOpInfo,
     dispatch_kind: common.DispatchKind,
     codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute,
     num_subgroups: int = 4,
@@ -332,11 +327,11 @@ def generate_attention_solutions(
 
     solver = z3.Solver()
     constraints = dispatch_constraints.generate_attention_vector_distribute_constraints(
-        qk_matmul,
-        pv_matmul,
-        transposed_q,
-        transposed_k,
-        transposed_v,
+        op_info.qk_matmul,
+        op_info.pv_matmul,
+        op_info.transposed_q,
+        op_info.transposed_k,
+        op_info.transposed_v,
         [m_var, n_var, k_var],
         num_subgroups,
         subgroup_size,
@@ -360,10 +355,10 @@
         lookup(qk_intrinsic_k),
     )
     qk_mma_attr = dispatch_constraints.getMMAAttr(
-        qk_matmul.acc_type,
+        op_info.qk_matmul.acc_type,
         *qk_intrinsic_mnk_shape,
-        qk_matmul.lhs_type,
-        qk_matmul.rhs_type,
+        op_info.qk_matmul.lhs_type,
+        op_info.qk_matmul.rhs_type,
         gpu_target_info.mma_intrinsics,
     )
 
@@ -373,38 +368,38 @@
         lookup(pv_intrinsic_k),
     )
     pv_mma_attr = dispatch_constraints.getMMAAttr(
-        pv_matmul.acc_type,
+        op_info.pv_matmul.acc_type,
         *pv_intrinsic_mnk_shape,
-        pv_matmul.lhs_type,
-        pv_matmul.rhs_type,
+        op_info.pv_matmul.lhs_type,
+        op_info.pv_matmul.rhs_type,
         gpu_target_info.mma_intrinsics,
     )
 
     # Get workgroup tile sizes.
-    workgroup_tile_sizes = [0] * opinfo.domain_rank
-    reduction_tile_sizes = [0] * opinfo.domain_rank
+    workgroup_tile_sizes = [0] * op_info.domain_rank
+    reduction_tile_sizes = [0] * op_info.domain_rank
 
-    for b in opinfo.batch_dims:
+    for b in op_info.batch_dims:
         workgroup_tile_sizes[b] = 1
-    for m in opinfo.m_dims[:-1]:
+    for m in op_info.m_dims[:-1]:
         workgroup_tile_sizes[m] = 1
-    for n in opinfo.n_dims[:-1]:
+    for n in op_info.n_dims[:-1]:
         workgroup_tile_sizes[n] = 1
-    for k2 in opinfo.k2_dims[:-1]:
+    for k2 in op_info.k2_dims[:-1]:
         reduction_tile_sizes[k2] = 1
 
-    workgroup_tile_sizes[opinfo.m_dims[-1]] = lookup(m_var)
-    workgroup_tile_sizes[opinfo.n_dims[-1]] = lookup(n_var)
-    reduction_tile_sizes[opinfo.k2_dims[-1]] = lookup(k_var)
+    workgroup_tile_sizes[op_info.m_dims[-1]] = lookup(m_var)
+    workgroup_tile_sizes[op_info.n_dims[-1]] = lookup(n_var)
+    reduction_tile_sizes[op_info.k2_dims[-1]] = lookup(k_var)
 
-    subgroup_basis_counts = [1] * opinfo.domain_rank
-    subgroup_basis_mapping = list(range(opinfo.domain_rank))
-    subgroup_basis_counts[opinfo.m_dims[-1]] = lookup(sg_m_cnt)
-    subgroup_basis_counts[opinfo.n_dims[-1]] = lookup(sg_n_cnt)
+    subgroup_basis_counts = [1] * op_info.domain_rank
+    subgroup_basis_mapping = list(range(op_info.domain_rank))
+    subgroup_basis_counts[op_info.m_dims[-1]] = lookup(sg_m_cnt)
+    subgroup_basis_counts[op_info.n_dims[-1]] = lookup(sg_n_cnt)
     qk_basis_mapping = [
         mapping
         for i, mapping in enumerate(subgroup_basis_mapping)
-        if i not in opinfo.n_dims
+        if i not in op_info.n_dims
     ]
     qk_config = {
         "mma_kind": qk_mma_attr,
@@ -419,7 +414,7 @@
     pv_basis_mapping = [
         mapping
         for i, mapping in enumerate(subgroup_basis_mapping)
-        if i not in opinfo.k1_dims
+        if i not in op_info.k1_dims
     ]
     pv_config = {
         "mma_kind": pv_mma_attr,
@@ -504,13 +499,7 @@ def generate_solutions(
 
 
 class ContractionOpInterfaceConstraintGenerator(ConstraintGenerator):
-    def __init__(
-        self, root_op: ir.Operation, op_info: dispatch_parser.ContractionOpInfo
-    ):
-        # TODO(Bangtian): Both root_op and op_info are kept as a temporary solution.
-        # Once convolution and attention ops are supported using the same structure,
-        # only op_info will be needed as it contains all necessary information.
-        self.root_op = root_op
+    def __init__(self, op_info: dispatch_parser.ContractionOpInfo):
         self.op_info = op_info
 
     def generate_solutions(
@@ -535,13 +524,7 @@ def generate_solutions(
 
 
 class ConvolutionOpInterfaceConstraintGenerator(ConstraintGenerator):
-    def __init__(
-        self, root_op: ir.Operation, op_info: dispatch_parser.ConvolutionOpInfo
-    ):
-        # TODO(Bangtian): Both root_op and op_info are kept as a temporary solution.
-        # Once all ops are supported using the same structure, only op_info will be
-        # needed as it contains all necessary information.
-        self.root_op = root_op
+    def __init__(self, op_info: dispatch_parser.ConvolutionOpInfo):
         self.op_info = op_info
 
     def generate_solutions(
@@ -569,102 +552,14 @@ class AttentionOpInterfaceConstraintGenerator(ConstraintGenerator):
     """
     Constraint generator for the IREE LinalgExt AttentionOp.
 
-    This class extracts structure information from the attention op and generates
-    constraints for exploring valid configurations to generate tuning specs. IREE
-    decomposes the operation into two matrix multiplications for the purpose of
-    Tiling:
-    - QK^T : Q @ K.T (producing scores)
-    - PV : P @ V (projected output after softmax)
-
-    Assumed operand shapes:
-    - Q : [B, M, K1]
-    - K : [B, K2, K1]
-    - V : [B, K2, N]
-    - O : [B, M, N]
+    Generates tuning configurations for attention operations.
 
     Attributes:
-        transposed_q (bool): True if Q is logically transposed (k1 dim is not last in map).
-        transposed_k (bool): True if K is logically transposed (k1 dim is not last in map).
-        transposed_v (bool): True if V is logically transposed (k2 dim is not last in map).
-        qk_matmul (MatmulShapeType): Shape metadata for Q @ K^T.
-        pv_matmul (MatmulShapeType): Shape metadata for P @ V.
-        opinfo: dimensions info for attention op.
+        op_info: AttentionOpInfo containing all attention operation metadata.
     """
 
-    def __init__(self, root_op: ir.Operation):
-        self.root_op = root_op
-        indexing_maps_attr = root_op.attributes["indexing_maps"]
-        indexing_maps = [attr.value for attr in indexing_maps_attr]
-        q_map = indexing_maps[0]
-        k_map = indexing_maps[1]
-        v_map = indexing_maps[2]
-        o_map = indexing_maps[-1]
-
-        raw_opinfo = iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map)
-        assert raw_opinfo, "no attention info"
-
-        self.opinfo = common.AttentionOpInfo(
-            domain_rank=raw_opinfo.domain_rank,
-            batch_dims=raw_opinfo.batch_dims,
-            m_dims=raw_opinfo.m_dims,
-            n_dims=raw_opinfo.n_dims,
-            k1_dims=raw_opinfo.k1_dims,
-            k2_dims=raw_opinfo.k2_dims,
-        )
-
-        q_type = ir.RankedTensorType(root_op.operands[0].type)
-        k_type = ir.RankedTensorType(root_op.operands[1].type)
-        v_type = ir.RankedTensorType(root_op.operands[2].type)
-        q_shape = q_type.shape
-        k_shape = k_type.shape
-        v_shape = v_type.shape
-        # QK matmul uses f32 as the accumulator type to match IREE's internal assumption.
-        # PV matmul derives the accumulator type from the output tensor's element type.
-        f32_type = ir.F32Type.get()
-        output_type = root_op.results[0].type.element_type
-
-        mDim = self.opinfo.m_dims[-1]
-        k1Dim = self.opinfo.k1_dims[-1]
-        k2Dim = self.opinfo.k2_dims[-1]
-        nDim = self.opinfo.n_dims[-1]
-
-        q_last_expr = q_map.results[-1]
-        k_last_expr = k_map.results[-1]
-        v_last_expr = v_map.results[-1]
-
-        q_dim_expr = ir.AffineDimExpr(q_last_expr)
-        k_dim_expr = ir.AffineDimExpr(k_last_expr)
-        v_dim_expr = ir.AffineDimExpr(v_last_expr)
-
-        self.transposed_k = k1Dim != k_dim_expr.position
-        self.transposed_v = k2Dim != v_dim_expr.position
-        self.transposed_q = k1Dim != q_dim_expr.position
-
-        q_dims = common.get_map_result_dim_positions(q_map)
-        k_dims = common.get_map_result_dim_positions(k_map)
-        v_dims = common.get_map_result_dim_positions(v_map)
-
-        assert q_dims, "no query dims from attention op"
-        assert k_dims, "no key dims from attention op"
-        assert v_dims, "no value dims from attention op"
-
-        self.qk_matmul = common.MatmulShapeType(
-            m=q_shape[q_dims.index(mDim)],
-            n=k_shape[k_dims.index(k2Dim)],
-            k=q_shape[q_dims.index(k1Dim)],
-            lhs_type=q_type.element_type,
-            rhs_type=k_type.element_type,
-            acc_type=f32_type,
-        )
-
-        self.pv_matmul = common.MatmulShapeType(
-            m=q_shape[q_dims.index(mDim)],
-            n=v_shape[v_dims.index(nDim)],
-            k=v_shape[v_dims.index(k2Dim)],
-            lhs_type=v_type.element_type,
-            rhs_type=v_type.element_type,
-            acc_type=output_type,
-        )
+    def __init__(self, op_info: dispatch_parser.AttentionOpInfo):
+        self.op_info = op_info
 
     def generate_solutions(
         self,
@@ -676,12 +571,7 @@ def generate_solutions(
         return generate_attention_solutions(
             tuner_ctx=tuner_context,
             gpu_target_info=gpu_target_info,
-            opinfo=self.opinfo,
-            qk_matmul=self.qk_matmul,
-            pv_matmul=self.pv_matmul,
-            transposed_q=self.transposed_q,
-            transposed_k=self.transposed_k,
-            transposed_v=self.transposed_v,
+            op_info=self.op_info,
             dispatch_kind=common.DispatchKind.attention,
             codegen_pipeline=codegen_pipeline,
             **pipeline_constraint_options,
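Putting the pieces together, a hedged sketch of how the simplified attention path is driven. The helper name is hypothetical and the argument order for generate_solutions is inferred from the keyword arguments visible in this hunk; the exact signature is not shown in the diff.

# Sketch only: assumes the sharktuner modules referenced throughout this diff.
from sharktuner import constraint_generator


def propose_attention_configs(op_info, tuner_context, gpu_target_info):
    # The generator now needs only the parsed op_info; it forwards it to
    # generate_attention_solutions with dispatch_kind=common.DispatchKind.attention.
    generator = constraint_generator.AttentionOpInterfaceConstraintGenerator(op_info)
    return generator.generate_solutions(
        tuner_context,     # assumed positional order; real parameter names may differ
        gpu_target_info,
    )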
