2424from tvm .contrib .cutlass .build import is_shape_valid_for_cutlass_matmul
2525from tvm .relax import (
2626 Call ,
27- DataflowVar ,
2827 ExternFunc ,
2928 Function ,
3029 PyExprMutator ,
4746 make_rms_norm_pattern ,
4847 make_stacked_attention_pattern ,
4948)
49+ from ..utils import has_leaking_intermediate_variables
5050
5151
5252def _is_supported_dtype (lhs_dtype , rhs_dtype ):
@@ -62,28 +62,6 @@ def _shape_1d(shape):
6262 return reduce (operator .mul , shape , 1 )
6363
6464
def _has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
    """
    Tell whether the matched region exposes any of its intermediate variables.

    Every variable bound inside the match, other than the variable bound to the
    matched expression itself, is treated as an intermediate.

    Parameters
    ----------
    context : PatternCheckContext
        The match under inspection. This function reads ``matched_bindings``,
        ``value_to_bound_var``, ``matched_expr`` and ``var_usages`` from it.

    Returns
    -------
    bool
        True when some intermediate is not a ``DataflowVar``, or when one of its
        recorded users is not itself bound inside the match; False otherwise.
    """
    region_vars = set(context.matched_bindings.keys())
    result_var = context.value_to_bound_var[context.matched_expr]
    inner_vars = {var for var in context.matched_bindings if var != result_var}

    # A variable that is not a DataflowVar can be accessed, and therefore
    # potentially used, outside the DataflowBlock.
    for var in inner_vars:
        if not isinstance(var, DataflowVar):
            return True

    # Every user of an intermediate must itself live inside the fused region.
    for var in inner_vars:
        for user in context.var_usages[var]:
            if user not in region_vars:
                return True

    return False
85-
86-
8765def _has_dependency (from_var : Var , to_var : Var , var_usages : Mapping [Var , Sequence [Var ]]):
8866 if from_var == to_var :
8967 return True
@@ -137,7 +115,7 @@ def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
137115
138116def _check_conv2d (context : PatternCheckContext ) -> bool :
139117 """Check if the given conv2d workload can be offloaded to CUTLASS."""
140- if _has_leaking_intermediate_variables (context ):
118+ if has_leaking_intermediate_variables (context ):
141119 return False
142120
143121 conv2d_call = context .annotated_expr ["root" ]
@@ -163,7 +141,7 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
163141
164142def _check_matmul (context : PatternCheckContext ) -> bool :
165143 """Check if the given matmul workload can be offloaded to CUTLASS."""
166- if _has_leaking_intermediate_variables (context ):
144+ if has_leaking_intermediate_variables (context ):
167145 return False
168146
169147 lhs = context .annotated_expr ["lhs" ]
@@ -229,7 +207,7 @@ def _matmul_pattern(pattern_name):
229207
230208def _check_decode_matmul (ctx ):
231209 """Check if the given decode -> matmul workload can be offloaded to CUTLASS."""
232- if _has_leaking_intermediate_variables (ctx ):
210+ if has_leaking_intermediate_variables (ctx ):
233211 return False
234212
235213 root = ctx .annotated_expr ["root" ]
@@ -391,7 +369,7 @@ def residual_block_patterns():
391369
392370def _check_stacked_attention (context : PatternCheckContext ) -> bool :
393371 """Check if the given stacked attention workload can be offloaded to CUTLASS."""
394- if _has_leaking_intermediate_variables (context ):
372+ if has_leaking_intermediate_variables (context ):
395373 return False
396374 if not context .annotated_expr ["stacked_qkv" ].struct_info .ndim == 3 :
397375 return False
0 commit comments