Skip to content

Commit ca441f4

Browse files
vinx13 authored and Archermmt committed
[Unity, BYOC] Add check for leaking intermediate variables for cublas and cudnn (apache#16175)
1 parent 836d648 commit ca441f4

File tree

5 files changed

+77
-27
lines changed

5 files changed

+77
-27
lines changed

python/tvm/relax/backend/contrib/cublas.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from ..pattern_registry import get_patterns_with_prefix, register_patterns
2727
from ..patterns import make_matmul_pattern
28+
from ..utils import has_leaking_intermediate_variables
2829

2930

3031
def _is_supported_dtype(lhs_dtype, rhs_dtype):
@@ -37,6 +38,8 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype):
3738

3839

3940
def _check_matmul(context: PatternCheckContext) -> bool:
41+
if has_leaking_intermediate_variables(context):
42+
return False
4043
lhs = context.annotated_expr["lhs"]
4144
rhs = context.annotated_expr["rhs"]
4245

python/tvm/relax/backend/contrib/cudnn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from ..pattern_registry import get_patterns_with_prefix, register_patterns
2323
from ..patterns import make_conv2d_pattern
24+
from ..utils import has_leaking_intermediate_variables
2425

2526

2627
def _is_supported_dtype(lhs_dtype, rhs_dtype):
@@ -38,6 +39,8 @@ def _is_supported_format(data_layout, kernel_layout):
3839

3940

4041
def _check_conv2d(context: PatternCheckContext) -> bool:
42+
if has_leaking_intermediate_variables(context):
43+
return False
4144
# Retrieve the annotated expression from context
4245
conv2d_call = context.annotated_expr["root"]
4346
input_expr = context.annotated_expr["input"]

python/tvm/relax/backend/contrib/cutlass.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
2525
from tvm.relax import (
2626
Call,
27-
DataflowVar,
2827
ExternFunc,
2928
Function,
3029
PyExprMutator,
@@ -47,6 +46,7 @@
4746
make_rms_norm_pattern,
4847
make_stacked_attention_pattern,
4948
)
49+
from ..utils import has_leaking_intermediate_variables
5050

5151

5252
def _is_supported_dtype(lhs_dtype, rhs_dtype):
@@ -62,28 +62,6 @@ def _shape_1d(shape):
6262
return reduce(operator.mul, shape, 1)
6363

6464

65-
def _has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
66-
"""
67-
Check whether intermediate variables in the region to be fused are used outside
68-
the fused region.
69-
"""
70-
defined_vars = set(context.matched_bindings.keys())
71-
output_var = context.value_to_bound_var[context.matched_expr]
72-
intermediate_vars = {v for v in context.matched_bindings if v != output_var}
73-
74-
if any(not isinstance(v, DataflowVar) for v in intermediate_vars):
75-
# If intermediate variable is not a DataflowVar, it can be accessed and potentially
76-
# used outside the DataflowBlock.
77-
return True
78-
79-
# Check whether all users of an intermediate variable are inside the fused region.
80-
for var in intermediate_vars:
81-
if any(var_user not in defined_vars for var_user in context.var_usages[var]):
82-
return True
83-
84-
return False
85-
86-
8765
def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequence[Var]]):
8866
if from_var == to_var:
8967
return True
@@ -137,7 +115,7 @@ def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
137115

138116
def _check_conv2d(context: PatternCheckContext) -> bool:
139117
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
140-
if _has_leaking_intermediate_variables(context):
118+
if has_leaking_intermediate_variables(context):
141119
return False
142120

143121
conv2d_call = context.annotated_expr["root"]
@@ -163,7 +141,7 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
163141

164142
def _check_matmul(context: PatternCheckContext) -> bool:
165143
"""Check if the given matmul workload can be offloaded to CUTLASS."""
166-
if _has_leaking_intermediate_variables(context):
144+
if has_leaking_intermediate_variables(context):
167145
return False
168146

169147
lhs = context.annotated_expr["lhs"]
@@ -229,7 +207,7 @@ def _matmul_pattern(pattern_name):
229207

230208
def _check_decode_matmul(ctx):
231209
"""Check if the given decode -> matmul workload can be offloaded to CUTLASS."""
232-
if _has_leaking_intermediate_variables(ctx):
210+
if has_leaking_intermediate_variables(ctx):
233211
return False
234212

235213
root = ctx.annotated_expr["root"]
@@ -391,7 +369,7 @@ def residual_block_patterns():
391369

392370
def _check_stacked_attention(context: PatternCheckContext) -> bool:
393371
"""Check if the given stacked attention workload can be offloaded to CUTLASS."""
394-
if _has_leaking_intermediate_variables(context):
372+
if has_leaking_intermediate_variables(context):
395373
return False
396374
if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3:
397375
return False

python/tvm/relax/backend/utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
"""Utils for BYOC pattern matching"""
19+
20+
from tvm.relax import DataflowVar
21+
from tvm.relax.transform import PatternCheckContext
22+
23+
24+
def has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
    """
    Return True when any intermediate variable of the matched (to-be-fused)
    region is visible or consumed outside that region, making fusion unsafe.
    """
    region_vars = set(context.matched_bindings.keys())
    output_var = context.value_to_bound_var[context.matched_expr]
    intermediates = [v for v in region_vars if v != output_var]

    # An intermediate that is not a DataflowVar can escape the enclosing
    # DataflowBlock, so it is conservatively treated as leaking.
    if not all(isinstance(v, DataflowVar) for v in intermediates):
        return True

    # Leaks if any consumer of an intermediate lies outside the fused region.
    return any(
        user not in region_vars
        for var in intermediates
        for user in context.var_usages[var]
    )

tests/python/relax/test_transform_fuse_ops_by_pattern.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from tvm.relax.transform import PatternCheckContext
2929
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
30+
from tvm.relax.backend.contrib.cublas import partition_for_cublas
3031
from tvm.script import ir as I
3132
from tvm.script import relax as R
3233
from tvm.script import tir as T
@@ -1023,5 +1024,27 @@ def main(
10231024
assert "fused_relax_matmul_relax_add_relax_add_cutlass" in func_names
10241025

10251026

1027+
def test_intermediate_var_to_var_binding():
    """Test that the intermediate var-to-var binding ``y1 = y`` breaks the fusion.

    Because ``y1`` aliases the matmul result ``y``, ``y`` is consumed by a
    binding outside the matched matmul region, so the leak check must stop
    ``add`` from being fused into the cuBLAS offloaded function.
    """

    @I.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor((1, 16), dtype="float16"), w: R.Tensor((16, 16), dtype="float16")
        ) -> R.Tensor((1, 16), dtype="float16"):
            with R.dataflow():
                w1: R.Tensor((16, 16), dtype="float16") = R.permute_dims(w, axes=None)
                y: R.Tensor((1, 16), dtype="float16") = R.matmul(x, w1)
                # Intermediate alias of y: makes y used outside the matched region.
                y1: R.Tensor((1, 16), dtype="float16") = y
                out: R.Tensor((1, 16), dtype="float16") = R.add(x, y1)
                R.output(out)
            return out

    mod = partition_for_cublas(Module)
    func_names = [name.name_hint for (name, _) in mod.functions.items()]
    assert "fused_relax_permute_dims_relax_matmul_cublas" in func_names  # add is not fused
1048+
10261049
if __name__ == "__main__":
10271050
pytest.main([__file__])

0 commit comments

Comments
 (0)