@@ -284,12 +284,7 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
 def generate_attention_solutions(
     tuner_ctx: common.TunerContext,
     gpu_target_info: iree_gpu.TargetInfo,
-    opinfo: common.AttentionOpInfo,
-    qk_matmul: common.MatmulShapeType,
-    pv_matmul: common.MatmulShapeType,
-    transposed_q: bool,
-    transposed_k: bool,
-    transposed_v: bool,
+    op_info: dispatch_parser.AttentionOpInfo,
     dispatch_kind: common.DispatchKind,
     codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute,
     num_subgroups: int = 4,
@@ -332,11 +327,11 @@ def generate_attention_solutions(
 
     solver = z3.Solver()
     constraints = dispatch_constraints.generate_attention_vector_distribute_constraints(
-        qk_matmul,
-        pv_matmul,
-        transposed_q,
-        transposed_k,
-        transposed_v,
+        op_info.qk_matmul,
+        op_info.pv_matmul,
+        op_info.transposed_q,
+        op_info.transposed_k,
+        op_info.transposed_v,
         [m_var, n_var, k_var],
         num_subgroups,
         subgroup_size,
@@ -360,10 +355,10 @@ def generate_attention_solutions(
             lookup(qk_intrinsic_k),
         )
         qk_mma_attr = dispatch_constraints.getMMAAttr(
-            qk_matmul.acc_type,
+            op_info.qk_matmul.acc_type,
             *qk_intrinsic_mnk_shape,
-            qk_matmul.lhs_type,
-            qk_matmul.rhs_type,
+            op_info.qk_matmul.lhs_type,
+            op_info.qk_matmul.rhs_type,
             gpu_target_info.mma_intrinsics,
         )
 
@@ -373,38 +368,38 @@ def generate_attention_solutions(
             lookup(pv_intrinsic_k),
         )
         pv_mma_attr = dispatch_constraints.getMMAAttr(
-            pv_matmul.acc_type,
+            op_info.pv_matmul.acc_type,
             *pv_intrinsic_mnk_shape,
-            pv_matmul.lhs_type,
-            pv_matmul.rhs_type,
+            op_info.pv_matmul.lhs_type,
+            op_info.pv_matmul.rhs_type,
             gpu_target_info.mma_intrinsics,
         )
 
         # Get workgroup tile sizes.
-        workgroup_tile_sizes = [0] * opinfo.domain_rank
-        reduction_tile_sizes = [0] * opinfo.domain_rank
+        workgroup_tile_sizes = [0] * op_info.domain_rank
+        reduction_tile_sizes = [0] * op_info.domain_rank
 
-        for b in opinfo.batch_dims:
+        for b in op_info.batch_dims:
             workgroup_tile_sizes[b] = 1
-        for m in opinfo.m_dims[:-1]:
+        for m in op_info.m_dims[:-1]:
             workgroup_tile_sizes[m] = 1
-        for n in opinfo.n_dims[:-1]:
+        for n in op_info.n_dims[:-1]:
             workgroup_tile_sizes[n] = 1
-        for k2 in opinfo.k2_dims[:-1]:
+        for k2 in op_info.k2_dims[:-1]:
             reduction_tile_sizes[k2] = 1
 
-        workgroup_tile_sizes[opinfo.m_dims[-1]] = lookup(m_var)
-        workgroup_tile_sizes[opinfo.n_dims[-1]] = lookup(n_var)
-        reduction_tile_sizes[opinfo.k2_dims[-1]] = lookup(k_var)
+        workgroup_tile_sizes[op_info.m_dims[-1]] = lookup(m_var)
+        workgroup_tile_sizes[op_info.n_dims[-1]] = lookup(n_var)
+        reduction_tile_sizes[op_info.k2_dims[-1]] = lookup(k_var)
 
-        subgroup_basis_counts = [1] * opinfo.domain_rank
-        subgroup_basis_mapping = list(range(opinfo.domain_rank))
-        subgroup_basis_counts[opinfo.m_dims[-1]] = lookup(sg_m_cnt)
-        subgroup_basis_counts[opinfo.n_dims[-1]] = lookup(sg_n_cnt)
+        subgroup_basis_counts = [1] * op_info.domain_rank
+        subgroup_basis_mapping = list(range(op_info.domain_rank))
+        subgroup_basis_counts[op_info.m_dims[-1]] = lookup(sg_m_cnt)
+        subgroup_basis_counts[op_info.n_dims[-1]] = lookup(sg_n_cnt)
         qk_basis_mapping = [
             mapping
            for i, mapping in enumerate(subgroup_basis_mapping)
-            if i not in opinfo.n_dims
+            if i not in op_info.n_dims
         ]
         qk_config = {
             "mma_kind": qk_mma_attr,
@@ -419,7 +414,7 @@ def generate_attention_solutions(
         pv_basis_mapping = [
             mapping
             for i, mapping in enumerate(subgroup_basis_mapping)
-            if i not in opinfo.k1_dims
+            if i not in op_info.k1_dims
         ]
         pv_config = {
             "mma_kind": pv_mma_attr,
@@ -504,13 +499,7 @@ def generate_solutions(
 
 
 class ContractionOpInterfaceConstraintGenerator(ConstraintGenerator):
-    def __init__(
-        self, root_op: ir.Operation, op_info: dispatch_parser.ContractionOpInfo
-    ):
-        # TODO(Bangtian): Both root_op and op_info are kept as a temporary solution.
-        # Once convolution and attention ops are supported using the same structure,
-        # only op_info will be needed as it contains all necessary information.
-        self.root_op = root_op
+    def __init__(self, op_info: dispatch_parser.ContractionOpInfo):
         self.op_info = op_info
 
     def generate_solutions(
@@ -535,13 +524,7 @@ def generate_solutions(
 
 
 class ConvolutionOpInterfaceConstraintGenerator(ConstraintGenerator):
-    def __init__(
-        self, root_op: ir.Operation, op_info: dispatch_parser.ConvolutionOpInfo
-    ):
-        # TODO(Bangtian): Both root_op and op_info are kept as a temporary solution.
-        # Once all ops are supported using the same structure, only op_info will be
-        # needed as it contains all necessary information.
-        self.root_op = root_op
+    def __init__(self, op_info: dispatch_parser.ConvolutionOpInfo):
         self.op_info = op_info
 
     def generate_solutions(
@@ -569,102 +552,14 @@ class AttentionOpInterfaceConstraintGenerator(ConstraintGenerator):
     """
     Constraint generator for the IREE LinalgExt AttentionOp.
 
-    This class extracts structure information from the attention op and generates
-    constraints for exploring valid configurations to generate tuning specs. IREE
-    decomposes the operation into two matrix multiplications for the purpose of
-    Tiling:
-    - QK^T : Q @ K.T (producing scores)
-    - PV : P @ V (projected output after softmax)
-
-    Assumed operand shapes:
-    - Q : [B, M, K1]
-    - K : [B, K2, K1]
-    - V : [B, K2, N]
-    - O : [B, M, N]
+    Generates tuning configurations for attention operations.
 
     Attributes:
-        transposed_q (bool): True if Q is logically transposed (k1 dim is not last in map).
-        transposed_k (bool): True if K is logically transposed (k1 dim is not last in map).
-        transposed_v (bool): True if V is logically transposed (k2 dim is not last in map).
-        qk_matmul (MatmulShapeType): Shape metadata for Q @ K^T.
-        pv_matmul (MatmulShapeType): Shape metadata for P @ V.
-        opinfo: dimensions info for attention op.
+        op_info: AttentionOpInfo containing all attention operation metadata.
     """
 
-    def __init__(self, root_op: ir.Operation):
-        self.root_op = root_op
-        indexing_maps_attr = root_op.attributes["indexing_maps"]
-        indexing_maps = [attr.value for attr in indexing_maps_attr]
-        q_map = indexing_maps[0]
-        k_map = indexing_maps[1]
-        v_map = indexing_maps[2]
-        o_map = indexing_maps[-1]
-
-        raw_opinfo = iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map)
-        assert raw_opinfo, "no attention info"
-
-        self.opinfo = common.AttentionOpInfo(
-            domain_rank=raw_opinfo.domain_rank,
-            batch_dims=raw_opinfo.batch_dims,
-            m_dims=raw_opinfo.m_dims,
-            n_dims=raw_opinfo.n_dims,
-            k1_dims=raw_opinfo.k1_dims,
-            k2_dims=raw_opinfo.k2_dims,
-        )
-
-        q_type = ir.RankedTensorType(root_op.operands[0].type)
-        k_type = ir.RankedTensorType(root_op.operands[1].type)
-        v_type = ir.RankedTensorType(root_op.operands[2].type)
-        q_shape = q_type.shape
-        k_shape = k_type.shape
-        v_shape = v_type.shape
-        # QK matmul uses f32 as the accumulator type to match IREE's internal assumption.
-        # PV matmul derives the accumulator type from the output tensor's element type.
-        f32_type = ir.F32Type.get()
-        output_type = root_op.results[0].type.element_type
-
-        mDim = self.opinfo.m_dims[-1]
-        k1Dim = self.opinfo.k1_dims[-1]
-        k2Dim = self.opinfo.k2_dims[-1]
-        nDim = self.opinfo.n_dims[-1]
-
-        q_last_expr = q_map.results[-1]
-        k_last_expr = k_map.results[-1]
-        v_last_expr = v_map.results[-1]
-
-        q_dim_expr = ir.AffineDimExpr(q_last_expr)
-        k_dim_expr = ir.AffineDimExpr(k_last_expr)
-        v_dim_expr = ir.AffineDimExpr(v_last_expr)
-
-        self.transposed_k = k1Dim != k_dim_expr.position
-        self.transposed_v = k2Dim != v_dim_expr.position
-        self.transposed_q = k1Dim != q_dim_expr.position
-
-        q_dims = common.get_map_result_dim_positions(q_map)
-        k_dims = common.get_map_result_dim_positions(k_map)
-        v_dims = common.get_map_result_dim_positions(v_map)
-
-        assert q_dims, "no query dims from attention op"
-        assert k_dims, "no key dims from attention op"
-        assert v_dims, "no value dims from attention op"
-
-        self.qk_matmul = common.MatmulShapeType(
-            m=q_shape[q_dims.index(mDim)],
-            n=k_shape[k_dims.index(k2Dim)],
-            k=q_shape[q_dims.index(k1Dim)],
-            lhs_type=q_type.element_type,
-            rhs_type=k_type.element_type,
-            acc_type=f32_type,
-        )
-
-        self.pv_matmul = common.MatmulShapeType(
-            m=q_shape[q_dims.index(mDim)],
-            n=v_shape[v_dims.index(nDim)],
-            k=v_shape[v_dims.index(k2Dim)],
-            lhs_type=v_type.element_type,
-            rhs_type=v_type.element_type,
-            acc_type=output_type,
-        )
+    def __init__(self, op_info: dispatch_parser.AttentionOpInfo):
+        self.op_info = op_info
 
     def generate_solutions(
         self,
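For context on the consolidated argument: the fields read from op_info in this diff (qk_matmul, pv_matmul, the transposed_q/k/v flags, domain_rank, and the batch/m/n/k1/k2 dimension lists) imply roughly the following shape for dispatch_parser.AttentionOpInfo. This is an inferred sketch only; the authoritative definition lives in dispatch_parser.py and is not part of this commit, and the field types are assumptions.

# Inferred sketch of dispatch_parser.AttentionOpInfo (not part of this commit).
# Field names come from the accesses visible above; types are assumptions.
from dataclasses import dataclass

@dataclass
class AttentionOpInfo:
    # Iteration-space structure of the attention op.
    domain_rank: int
    batch_dims: list[int]
    m_dims: list[int]
    n_dims: list[int]
    k1_dims: list[int]
    k2_dims: list[int]
    # Shape/type metadata for the two matmuls the op is tiled as.
    qk_matmul: "common.MatmulShapeType"  # Q @ K^T
    pv_matmul: "common.MatmulShapeType"  # P @ V
    # Transposition flags derived from the operand indexing maps.
    transposed_q: bool
    transposed_k: bool
    transposed_v: bool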
@@ -676,12 +571,7 @@ def generate_solutions(
         return generate_attention_solutions(
             tuner_ctx=tuner_context,
             gpu_target_info=gpu_target_info,
-            opinfo=self.opinfo,
-            qk_matmul=self.qk_matmul,
-            pv_matmul=self.pv_matmul,
-            transposed_q=self.transposed_q,
-            transposed_k=self.transposed_k,
-            transposed_v=self.transposed_v,
+            op_info=self.op_info,
             dispatch_kind=common.DispatchKind.attention,
             codegen_pipeline=codegen_pipeline,
             **pipeline_constraint_options,
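As a usage illustration, a constraint generator is now built from parsed op info alone. The following is a minimal call-site sketch under the assumption that dispatch_parser produces the AttentionOpInfo from the attention root op; everything here other than the one-argument constructor introduced above is a placeholder, not an API shown in this commit.

# Hypothetical call-site sketch; only the one-argument constructor is
# actually introduced by this commit.
op_info = dispatch_parser.AttentionOpInfo(...)  # placeholder: built by dispatch_parser from the attention root op
generator = AttentionOpInterfaceConstraintGenerator(op_info)
# The generator derives the QK/PV matmul shapes, transposition flags, and
# dimension indices entirely from op_info; no root_op access is needed.
# Exact generate_solutions parameters are not shown in this excerpt, so the
# call below is illustrative only.
solutions = generator.generate_solutions(tuner_context, gpu_target_info, codegen_pipeline)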