174 changes: 174 additions & 0 deletions paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.cc
@@ -14,8 +14,10 @@

#include "paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"

@@ -496,6 +498,177 @@ class MultiHeadMatmulFuseWithBiasQKPattern
}
};

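// Fuses the ViT-style attention subgraph (fused QKV matmul + bias add +
// reshape/transpose + Q/K/V slices + scaled softmax attention) into a single
// multihead_matmul op.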
class VitAttentionFusePattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();
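// Declare the source-pattern ops and capture the attributes needed by the
// constraint and result patterns.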
const auto &matmul_1 = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("transpose_x_1")},
{"transpose_y", pat.Attr("transpose_y_1")}});
const auto &matmul_2 = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("transpose_x_2")},
{"transpose_y", pat.Attr("transpose_y_2")}});
const auto &matmul_3 = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("transpose_x_3")},
{"transpose_y", pat.Attr("transpose_y_3")}});
const auto &add = pat.Op(paddle::dialect::AddOp::name());
const auto &full_int_array_1 =
pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_int_array_value_1")}});
const auto &reshape_1 = pat.Op(paddle::dialect::ReshapeOp::name());
const auto &full_int_array_2 =
pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_int_array_value_2")}});
const auto &reshape_2 = pat.Op(paddle::dialect::ReshapeOp::name());
const auto &transpose_1 = pat.Op(paddle::dialect::TransposeOp::name(),
{{"perm", pat.Attr("perm_1")}});
const auto &transpose_2 = pat.Op(paddle::dialect::TransposeOp::name(),
{{"perm", pat.Attr("perm_2")}});

const auto &transpose_3 = pat.Op(paddle::dialect::TransposeOp::name(),
{{"perm", pat.Attr("perm_3")}});
const auto &full_int_array_3 =
pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_int_array_value_3")}});
const auto &full_int_array_4 =
pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_int_array_value_4")}});
const auto &slice_1 =
pat.Op(paddle::dialect::SliceOp::name(),
{{"axes", pat.Attr("axes_1")},
{"infer_flags", pat.Attr("infer_flags_1")},
{"decrease_axis", pat.Attr("decrease_axis_1")}});
const auto &full_int_array_5 =
pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_int_array_value_5")}});
const auto &full_int_array_6 =
pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_int_array_value_6")}});
const auto &slice_2 =
pat.Op(paddle::dialect::SliceOp::name(),
{{"axes", pat.Attr("axes_2")},
{"infer_flags", pat.Attr("infer_flags_2")},
{"decrease_axis", pat.Attr("decrease_axis_2")}});
const auto &full_int_array_7 =
pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_int_array_value_7")}});
const auto &full_int_array_8 =
pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_int_array_value_8")}});
const auto &slice_3 =
pat.Op(paddle::dialect::SliceOp::name(),
{{"axes", pat.Attr("axes_3")},
{"infer_flags", pat.Attr("infer_flags_3")},
{"decrease_axis", pat.Attr("decrease_axis_3")}});
const auto &full_1 = pat.Op(paddle::dialect::FullOp::name(),
{{"value", pat.Attr("full_1_value")}});
const auto &scale =
pat.Op(paddle::dialect::ScaleOp::name(),
{{"bias", pat.Attr("scale_bias")},
{"bias_after_scale", pat.Attr("bias_after_scale")}});
const auto &softmax = pat.Op(paddle::dialect::SoftmaxOp::name(),
{{"axis", pat.Attr("axis")}});

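// Wire up the source pattern: QKV projection, reshape/transpose, Q/K/V
// slices, scaled Q*K^T softmax, and the final attention output reshape.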
pat.Tensor("matmul_out_1") = matmul_1(pat.Tensor("x1"), pat.Tensor("w1"));
pat.Tensor("add_1_out") =
add(pat.Tensor("matmul_out_1"), pat.Tensor("bias"));
reshape_1({&pat.Tensor("add_1_out"), &full_int_array_1()},
{&pat.Tensor("reshape_1_out"), &pat.Tensor("reshape_1_xshape")});
pat.Tensor("transpose_1_out") = transpose_1(pat.Tensor("reshape_1_out"));
pat.Tensor("slice_out_1") = slice_1(
pat.Tensor("transpose_1_out"), full_int_array_3(), full_int_array_4());
pat.Tensor("slice_out_2") = slice_2(
pat.Tensor("transpose_1_out"), full_int_array_5(), full_int_array_6());
pat.Tensor("slice_out_3") = slice_3(
pat.Tensor("transpose_1_out"), full_int_array_7(), full_int_array_8());

pat.Tensor("transpose_2_out") = transpose_2(pat.Tensor("slice_out_3"));
pat.Tensor("matmul_out_2") =
matmul_2(pat.Tensor("slice_out_2"), pat.Tensor("transpose_2_out"));
pat.Tensor("scale_out") = scale(pat.Tensor("matmul_out_2"), full_1());
pat.Tensor("softmax_out") = softmax(pat.Tensor("scale_out"));
pat.Tensor("matmul_out_3") =
matmul_3(pat.Tensor("softmax_out"), pat.Tensor("slice_out_1"));
pat.Tensor("transpose_3_out") = transpose_3(pat.Tensor("matmul_out_3"));
reshape_2({&pat.Tensor("transpose_3_out"), &full_int_array_2()},
{&pat.Tensor("reshape_2_out"), &pat.Tensor("reshape_2_xshape")});

// Constraints: softmax over the last axis, a 3-D matmul_1 output, and no
// transposed matmul inputs.
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
auto softmax_axis = match_ctx.Attr<int>("axis");
if (softmax_axis != -1 && softmax_axis != 3) return false;
auto matmul_out_1_shape =
pir::GetShapeFromValue(match_ctx.Tensor("matmul_out_1"));
if (matmul_out_1_shape.size() != 3) {
return false;
}
bool matmul_1_transpose_x = match_ctx.Attr<bool>("transpose_x_1");
bool matmul_1_transpose_y = match_ctx.Attr<bool>("transpose_y_1");
if (matmul_1_transpose_x || matmul_1_transpose_y) return false;
bool matmul_2_transpose_x = match_ctx.Attr<bool>("transpose_x_2");
bool matmul_2_transpose_y = match_ctx.Attr<bool>("transpose_y_2");
if (matmul_2_transpose_x || matmul_2_transpose_y) return false;
bool matmul_3_transpose_x = match_ctx.Attr<bool>("transpose_x_3");
bool matmul_3_transpose_y = match_ctx.Attr<bool>("transpose_y_3");
if (matmul_3_transpose_x || matmul_3_transpose_y) return false;
return true;
});

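// Result pattern: reshape the fused QKV weight and bias, then replace the
// matched subgraph with a single multihead_matmul op.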
paddle::drr::ResultPattern res = pat.ResultPattern();

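// Weight reshape.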
const auto &reshape_w_shape_attr = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
auto w1_shape = pir::GetShapeFromValue(match_ctx.Tensor("w1"));
auto dim_0 = w1_shape.at(0);
auto dim_2 = w1_shape.at(1) / 3;
return std::vector<int64_t>({dim_0, 3, dim_2});
});

const auto &res_reshape1 = res.Op(paddle::dialect::ReshapeOp::name(),
{{"shape", reshape_w_shape_attr}});
res_reshape1({&res.Tensor("w1")},
{&res.Tensor("reshape_w_out"), &res.OutputNoneTensor()});
// Bias reshape.
const auto &reshape_b_shape_attr = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
auto bias_shape = pir::GetShapeFromValue(match_ctx.Tensor("bias"));
auto dim = bias_shape.at(0) / 3;
return std::vector<int64_t>({3, dim});
});

const auto &res_reshape2 = res.Op(paddle::dialect::ReshapeOp::name(),
{{"shape", reshape_b_shape_attr}});
res_reshape2({&res.Tensor("bias")},
{&res.Tensor("reshape_bias_out"), &res.OutputNoneTensor()});

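// head_number is the second dimension of the softmax output; alpha is the
// scale factor captured from the full op feeding scale.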
const auto &head_number =
res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int {
return pir::GetShapeFromValue(match_ctx.Tensor("softmax_out")).at(1);
});

const auto &alpha = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> float {
return match_ctx.Attr<float>("full_1_value");
});

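// The fourth input (the QK bias used by other patterns in this pass) is
// left as none here.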
const auto &multihead_matmul_op =
res.Op(paddle::dialect::MultiheadMatmulOp::name(),
{{
{"transpose_q", res.BoolAttr(false)},
{"transpose_k", res.BoolAttr(false)},
{"transpose_v", res.BoolAttr(false)},
{"alpha", alpha},
{"head_number", head_number},
}});
multihead_matmul_op({&res.Tensor("x1"),
&res.Tensor("reshape_w_out"),
&res.Tensor("reshape_bias_out"),
&res.InputNoneTensor()},
{&res.Tensor("reshape_2_out")});
}
std::string name() const override { return "VitAttentionFusePattern"; }
};

class MultiHeadMatmulFusePass : public pir::PatternRewritePass {
public:
MultiHeadMatmulFusePass()
@@ -505,6 +678,7 @@ class MultiHeadMatmulFusePass : public pir::PatternRewritePass {
pir::RewritePatternSet ps(context);
ps.Add(paddle::drr::Create<MultiHeadMatmulFuseNoBiasQKPattern>(context));
ps.Add(paddle::drr::Create<MultiHeadMatmulFuseWithBiasQKPattern>(context));
ps.Add(paddle::drr::Create<VitAttentionFusePattern>(context));
// Add other attention variant fuse pattern.

return ps;
1 change: 1 addition & 0 deletions test/ir/pir/fused_pass/CMakeLists.txt
@@ -15,3 +15,4 @@ endif()
foreach(target ${TEST_INTERP_CASES})
py_test_modules(${target} MODULES ${target})
endforeach()
set_tests_properties(test_pir_multihead_matmul_fuse_pass PROPERTIES TIMEOUT 100)
157 changes: 157 additions & 0 deletions test/ir/pir/fused_pass/test_pir_multihead_matmul_fuse_pass.py
@@ -0,0 +1,157 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest

import paddle
from paddle.base import core
from paddle.pir.core import create_parameter

np.random.seed(42)
paddle.enable_static()


class TestVitAttentionPattern(PassTest):
r'''
x w
| |
matmul bias
| |
elementwise_add
|
reshape
|
transpose
/ | \
slice slice slice
| | |
| | transpose
| | |
| matmul
| |
| scale
| |
| softmax
\ /
\ /
matmul
|
transpose
|
reshape
'''

def is_program_valid(self, program):
return True

def build_ir_program(self):
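# Build the unfused ViT attention subgraph (QKV matmul + add + reshape +
# transpose + slices + scale/softmax + matmul) that the pass should fuse
# into a single pd_op.multihead_matmul.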
for bs in [1]:
for seq_len in [128]:
for head_dim in [64]:
for num_heads in [12]:
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(
main_prog, start_prog
):
hidden_dim = head_dim * num_heads
x = paddle.static.data(
name='x',
shape=[bs, seq_len, hidden_dim],
dtype='float32',
)
bias = paddle.static.data(
name='bias',
shape=[3 * hidden_dim],
dtype='float32',
)

w = create_parameter(
name="w",
shape=[hidden_dim, 3 * hidden_dim],
dtype='float32',
initializer=paddle.nn.initializer.Assign(
np.random.rand(
hidden_dim, 3 * hidden_dim
).astype(np.float32)
),
)
matmul_out_1 = paddle.matmul(x, w)
add_out = paddle.add(matmul_out_1, bias)
# reshape to [bs, seq_len, 3, num_heads, head_dim]
reshape_out_1 = paddle.reshape(
add_out,
shape=[bs, seq_len, 3, num_heads, head_dim],
)
transpose_out_1 = paddle.transpose(
reshape_out_1, perm=[2, 0, 3, 1, 4]
)
# q/k/v: [bs, num_heads, seq_len, head_dim]
q = transpose_out_1[0, :, :, :, :]
k = transpose_out_1[1, :, :, :, :]
v = transpose_out_1[2, :, :, :, :]
matmul_out_2 = paddle.matmul(
q, paddle.transpose(k, perm=[0, 1, 3, 2])
)
scale_out = paddle.scale(
matmul_out_2,
scale=0.125,
bias=0.0,
)
softmax_out = paddle.nn.functional.softmax(
scale_out
)
# context: [bs, num_heads, seq_len, head_dim]
matmul_out_3 = paddle.matmul(softmax_out, v)
transpose_out_2 = paddle.transpose(
matmul_out_3, perm=[0, 2, 1, 3]
)
reshape_out_2 = paddle.reshape(
transpose_out_2,
shape=[bs, seq_len, num_heads * head_dim],
)
out = paddle.assign(reshape_out_2)
self.pass_list = ['multihead_matmul_fuse_pass']
self.feeds = {
"x": np.random.random(
(bs, seq_len, hidden_dim)
).astype("float32")
- 0.5,
"bias": np.random.random(
3 * hidden_dim
).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.multihead_matmul": 1,
}
return [main_prog, start_prog]

def sample_program(self):
yield self.build_ir_program(), False

def setUp(self):
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))

def test_check_output(self):
self.check_pass_correct(atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
unittest.main()