Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/tvm/relax/backend/contrib/cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_matmul_pattern
from ..utils import has_leaking_intermediate_variables


def _is_supported_dtype(lhs_dtype, rhs_dtype):
Expand All @@ -37,6 +38,8 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype):


def _check_matmul(context: PatternCheckContext) -> bool:
if has_leaking_intermediate_variables(context):
return False
lhs = context.annotated_expr["lhs"]
rhs = context.annotated_expr["rhs"]

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/relax/backend/contrib/cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_conv2d_pattern
from ..utils import has_leaking_intermediate_variables


def _is_supported_dtype(lhs_dtype, rhs_dtype):
Expand All @@ -38,6 +39,8 @@ def _is_supported_format(data_layout, kernel_layout):


def _check_conv2d(context: PatternCheckContext) -> bool:
if has_leaking_intermediate_variables(context):
return False
# Retrieve the annotated expression from context
conv2d_call = context.annotated_expr["root"]
input_expr = context.annotated_expr["input"]
Expand Down
32 changes: 5 additions & 27 deletions python/tvm/relax/backend/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
from tvm.relax import (
Call,
DataflowVar,
ExternFunc,
Function,
PyExprMutator,
Expand All @@ -47,6 +46,7 @@
make_rms_norm_pattern,
make_stacked_attention_pattern,
)
from ..utils import has_leaking_intermediate_variables


def _is_supported_dtype(lhs_dtype, rhs_dtype):
Expand All @@ -62,28 +62,6 @@ def _shape_1d(shape):
return reduce(operator.mul, shape, 1)


def _has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
"""
Check whether intermediate variables in the region to be fused are used outside
the fused region.
"""
defined_vars = set(context.matched_bindings.keys())
output_var = context.value_to_bound_var[context.matched_expr]
intermediate_vars = {v for v in context.matched_bindings if v != output_var}

if any(not isinstance(v, DataflowVar) for v in intermediate_vars):
# If intermediate variable is not a DataflowVar, it can be accessed and potentially
# used outside the DataflowBlock.
return True

# Check whether all users of an intermediate variable are inside the fused region.
for var in intermediate_vars:
if any(var_user not in defined_vars for var_user in context.var_usages[var]):
return True

return False


def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequence[Var]]):
if from_var == to_var:
return True
Expand Down Expand Up @@ -137,7 +115,7 @@ def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:

def _check_conv2d(context: PatternCheckContext) -> bool:
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
if _has_leaking_intermediate_variables(context):
if has_leaking_intermediate_variables(context):
return False

conv2d_call = context.annotated_expr["root"]
Expand All @@ -163,7 +141,7 @@ def _check_conv2d(context: PatternCheckContext) -> bool:

def _check_matmul(context: PatternCheckContext) -> bool:
"""Check if the given matmul workload can be offloaded to CUTLASS."""
if _has_leaking_intermediate_variables(context):
if has_leaking_intermediate_variables(context):
return False

lhs = context.annotated_expr["lhs"]
Expand Down Expand Up @@ -229,7 +207,7 @@ def _matmul_pattern(pattern_name):

def _check_decode_matmul(ctx):
"""Check if the given decode -> matmul workload can be offloaded to CUTLASS."""
if _has_leaking_intermediate_variables(ctx):
if has_leaking_intermediate_variables(ctx):
return False

root = ctx.annotated_expr["root"]
Expand Down Expand Up @@ -391,7 +369,7 @@ def residual_block_patterns():

def _check_stacked_attention(context: PatternCheckContext) -> bool:
"""Check if the given stacked attention workload can be offloaded to CUTLASS."""
if _has_leaking_intermediate_variables(context):
if has_leaking_intermediate_variables(context):
return False
if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3:
return False
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/relax/backend/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Utils for BYOC pattern matching"""

from tvm.relax import DataflowVar
from tvm.relax.transform import PatternCheckContext


def has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
"""
Check whether intermediate variables in the region to be fused are used outside
the fused region.
"""
defined_vars = set(context.matched_bindings.keys())
output_var = context.value_to_bound_var[context.matched_expr]
intermediate_vars = {v for v in context.matched_bindings if v != output_var}

if any(not isinstance(v, DataflowVar) for v in intermediate_vars):
# If intermediate variable is not a DataflowVar, it can be accessed and potentially
# used outside the DataflowBlock.
return True

# Check whether all users of an intermediate variable are inside the fused region.
for var in intermediate_vars:
if any(var_user not in defined_vars for var_user in context.var_usages[var]):
return True

return False
23 changes: 23 additions & 0 deletions tests/python/relax/test_transform_fuse_ops_by_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from tvm.relax.transform import PatternCheckContext
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.relax.backend.contrib.cublas import partition_for_cublas
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
Expand Down Expand Up @@ -1023,5 +1024,27 @@ def main(
assert "fused_relax_matmul_relax_add_relax_add_cutlass" in func_names


def test_intermediate_var_to_var_binding():
"""test the intermediate binding y1 will break the fusion"""

@I.ir_module
class Module:
@R.function
def main(
x: R.Tensor((1, 16), dtype="float16"), w: R.Tensor((16, 16), dtype="float16")
) -> R.Tensor((1, 16), dtype="float16"):
with R.dataflow():
w1: R.Tensor((16, 16), dtype="float16") = R.permute_dims(w, axes=None)
y: R.Tensor((1, 16), dtype="float16") = R.matmul(x, w1)
y1: R.Tensor((1, 16), dtype="float16") = y
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think such binding can be removed by CanonicalizeBinding pass

out: R.Tensor((1, 16), dtype="float16") = R.add(x, y1)
R.output(out)
return out

mod = partition_for_cublas(Module)
func_names = [name.name_hint for (name, _) in mod.functions.items()]
assert "fused_relax_permute_dims_relax_matmul_cublas" in func_names # add is not fused


if __name__ == "__main__":
pytest.main([__file__])