|
| 1 | +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include <memory> |
| 16 | +#include <string> |
| 17 | +#include "lite/backends/xpu/math.h" |
| 18 | +#include "lite/core/optimizer/mir/pass_registry.h" |
| 19 | +#include "lite/core/optimizer/mir/pattern_matcher_high_api.h" |
| 20 | + |
| 21 | +namespace paddle { |
| 22 | +namespace lite { |
| 23 | +namespace mir { |
| 24 | +namespace fusion { |
| 25 | + |
| 26 | +/* support adaptive seq len for mrc */ |
| 27 | +/* in_Input in_Mask */ |
| 28 | +/* | | */ |
| 29 | +/* | | */ |
| 30 | +/* | matmul */ |
| 31 | +/* | | */ |
| 32 | +/* | | */ |
| 33 | +/* | scale */ |
| 34 | +/* | / */ |
| 35 | +/* | stack */ |
| 36 | +/* | | */ |
| 37 | +/* | / */ |
| 38 | +/* | / */ |
| 39 | +/* xpu_encoder */ |
| 40 | +/* | */ |
| 41 | +/* | */ |
| 42 | +/* out_Output */ |
| 43 | +/*-------------------------------------------*/ |
| 44 | +/* After the pass apply: */ |
| 45 | +/* in_Input in_Mask */ |
| 46 | +/* | | */ |
| 47 | +/* | | */ |
| 48 | +/* | xpu_adaptive_mask */ |
| 49 | +/* | | | */ |
| 50 | +/*       sequence_unpad<--Length       |     */
| 51 | +/* | | */ |
| 52 | +/* | PadSeqLen */ |
| 53 | +/* | SeqLod */ |
| 54 | +/* | / */ |
| 55 | +/* | / */ |
| 56 | +/* | / */ |
| 57 | +/* xpu_encoder */ |
| 58 | +/* | */ |
| 59 | +/* | */ |
| 60 | +/* out_Output */ |
| 61 | +/*-------------------------------------------*/ |
| 62 | + |
// Fuses the mask pre-processing subgraph (matmul -> scale -> stack) feeding
// an adaptive-seqlen __xpu__multi_encoder into a single __xpu__mask_adaptive
// op plus a sequence_unpad that packs the encoder input by true sequence
// length (variable-sequence-length / "VSL" execution for MRC-style models).
class XPUMultiEncoderAdaptiveSeqlenV3Fuser : public FuseBase {
 public:
  // matmul_type selects which matmul flavor heads the matched subgraph
  // ("matmul" or "matmul_v2"); the pass driver tries both.
  explicit XPUMultiEncoderAdaptiveSeqlenV3Fuser(
      const std::string& matmul_type = "matmul")
      : matmul_type_(matmul_type) {}

  void BuildPattern() override {
    // The mask tensor is multiplied with itself: it feeds both the X and Y
    // inputs of the same matmul (self outer-product style attention mask).
    auto* mask = VarNode("mask")
                     ->assert_is_op_input(matmul_type_, "X")
                     ->assert_is_op_input(matmul_type_, "Y");
    // All nodes between mask and the encoder are intermediates: they are
    // consumed by the fusion and removed from the graph afterwards.
    auto* matmul = OpNode(matmul_type_, matmul_type_)->AsIntermediate();
    auto* matmul_out = VarNode("matmul_out")
                           ->assert_is_op_input("scale", "X")
                           ->assert_is_op_output(matmul_type_, "Out")
                           ->AsIntermediate();
    // Match only the canonical mask transform out = (mask - 1) * 10000 with
    // bias applied before the scale, i.e. bias == -1 and scale == 10000
    // (compared with a small tolerance because attrs are floats).
    auto* scale =
        OpNode("scale", "scale")
            ->assert_op_attr<bool>("bias_after_scale", false)
            ->assert_op_attr_satisfied<float>(
                "bias",
                [](float attr) { return (std::fabs(attr + 1.0) < 1e-5); })
            ->assert_op_attr_satisfied<float>(
                "scale",
                [](float attr) { return (std::fabs(attr - 10000.0) < 1e-5); })
            ->AsIntermediate();
    auto* scale_out = VarNode("scale_out")
                          ->assert_is_op_input("stack", "X")
                          ->assert_is_op_output("scale", "Out")
                          ->AsIntermediate();
    // The stack op replicates the 2-D mask per attention head before it is
    // consumed as the encoder's "Mask" input.
    auto* stack = OpNode("stack", "stack")->AsIntermediate();
    auto* stack_out = VarNode("stack_out")
                          ->assert_is_op_input("__xpu__multi_encoder", "Mask")
                          ->assert_is_op_output("stack", "Y")
                          ->AsIntermediate();
    auto* encoder_input =
        VarNode("encoder_input")
            ->assert_is_op_input("__xpu__multi_encoder", "Input");
    // Only encoders already flagged for adaptive sequence length qualify.
    auto* xpu_encoder = OpNode("xpu_encoder", "__xpu__multi_encoder")
                            ->assert_op_attr<bool>("adaptive_seqlen", true);

    // Wire the pattern: mask -> matmul -> scale -> stack -> encoder, with the
    // padded activations arriving on the encoder's other input.
    *mask >> *matmul >> *matmul_out >> *scale >> *scale_out >> *stack >>
        *stack_out >> *xpu_encoder;
    *encoder_input >> *xpu_encoder;
  }

  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
    auto* encoder_instruct = matched.at("xpu_encoder")->stmt();
    auto encoder_op_desc = encoder_instruct->mutable_op_info();
    auto encoder_op = encoder_instruct->op();
    auto* scope = encoder_op->scope();

    // New argument "seq_lod": per-batch LoD (int32, host) consumed by the
    // encoder to locate each sequence inside the packed input.
    std::string stack_out_name = matched.at("stack_out")->arg()->name;
    std::string xpu_mask_adaptive_seq_lod_name = stack_out_name + "_seq_lod";
    auto* xpu_mask_adaptive_seq_lod_node =
        graph->NewArgumentNode(xpu_mask_adaptive_seq_lod_name);
    xpu_mask_adaptive_seq_lod_node->arg()->type = LiteType::GetTensorTy(
        TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kNCHW));
    scope->NewTensor(xpu_mask_adaptive_seq_lod_name);
    // New argument "pad_seq_len": scalar holding the max padded length of
    // the batch (int32, host).
    std::string xpu_mask_adaptive_pad_seq_len_name =
        stack_out_name + "_pad_seq_len";
    auto* xpu_mask_adaptive_pad_seq_len_node =
        graph->NewArgumentNode(xpu_mask_adaptive_pad_seq_len_name);
    xpu_mask_adaptive_pad_seq_len_node->arg()->type = LiteType::GetTensorTy(
        TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kNCHW));
    scope->NewTensor(xpu_mask_adaptive_pad_seq_len_name);
    // New argument "length": true per-sequence lengths (int64, host) fed to
    // sequence_unpad; int64 matches that op's expected Length dtype.
    std::string xpu_mask_adaptive_seq_len_name = stack_out_name + "_seq_length";
    auto* xpu_mask_adaptive_seq_len_node =
        graph->NewArgumentNode(xpu_mask_adaptive_seq_len_name);
    xpu_mask_adaptive_seq_len_node->arg()->type = LiteType::GetTensorTy(
        TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kNCHW));
    scope->NewTensor(xpu_mask_adaptive_seq_len_name);

    // New packed (unpadded) encoder input produced by sequence_unpad.
    // NOTE(review): precision is fixed to kFloat here; presumably the packed
    // activation is fp32 at this stage — confirm for fp16 encoder variants.
    std::string orig_encoder_input_name =
        matched.at("encoder_input")->arg()->name;
    std::string packed_encoder_input_name =
        orig_encoder_input_name + "_vsl_packed";
    auto* packed_encoder_input_node =
        graph->NewArgumentNode(packed_encoder_input_name);
    packed_encoder_input_node->arg()->type = LiteType::GetTensorTy(
        TARGET(kXPU), PRECISION(kFloat), DATALAYOUT(kNCHW));
    scope->NewTensor(packed_encoder_input_name);

    // Build the __xpu__mask_adaptive op: derives Length / SeqLod / PadSeqLen
    // directly from the raw mask, replacing the matmul+scale+stack chain.
    cpp::OpDesc op_desc;
    op_desc.SetType("__xpu__mask_adaptive");
    op_desc.SetInput("Mask", {matched.at("mask")->arg()->name});
    op_desc.SetOutput(
        "Length",
        {xpu_mask_adaptive_seq_len_name});  // length for sequence_unpad op
    op_desc.SetOutput("SeqLod",
                      {xpu_mask_adaptive_seq_lod_name});  // lod for encoder op
    op_desc.SetOutput("PadSeqLen", {xpu_mask_adaptive_pad_seq_len_name});
    auto xpu_mask_adaptive_op =
        LiteOpRegistry::Global().Create("__xpu__mask_adaptive");
    auto& valid_places = encoder_op->valid_places();
    // Attach must precede instruction-node creation so the op is bound to
    // the scope before the graph takes ownership.
    xpu_mask_adaptive_op->Attach(op_desc, scope);
    auto* xpu_mask_adaptive_node =
        graph->GraphCreateInstructNode(xpu_mask_adaptive_op, valid_places);

    // Build the sequence_unpad op that strips padding from the encoder input
    // using the lengths computed above.
    cpp::OpDesc sequence_unpad_op_desc;
    sequence_unpad_op_desc.SetType("sequence_unpad");
    sequence_unpad_op_desc.SetInput("X",
                                    {matched.at("encoder_input")->arg()->name});
    sequence_unpad_op_desc.SetInput("Length", {xpu_mask_adaptive_seq_len_name});
    sequence_unpad_op_desc.SetOutput("Out", {packed_encoder_input_name});
    auto sequence_unpad_op = LiteOpRegistry::Global().Create("sequence_unpad");
    sequence_unpad_op->Attach(sequence_unpad_op_desc, scope);
    auto* sequence_unpad_node =
        graph->GraphCreateInstructNode(sequence_unpad_op, valid_places);

    // Retarget the encoder: packed input plus the two new VSL side inputs,
    // then rebuild its kernel selection from the updated desc.
    encoder_op_desc->SetInput("Input", {packed_encoder_input_name});
    encoder_op_desc->SetInput("SeqLod", {xpu_mask_adaptive_seq_lod_name});
    encoder_op_desc->SetInput("PadSeqLen",
                              {xpu_mask_adaptive_pad_seq_len_name});
    auto updated_encoder_op_desc = *encoder_instruct->mutable_op_info();
    encoder_instruct->ResetOp(updated_encoder_op_desc, valid_places);

    // Rewire the graph edges to match the new op topology. The old direct
    // encoder_input -> encoder edge is cut; everything else is additive
    // (the intermediate pattern nodes are removed by the matcher itself).
    RemoveDirectedLink(matched.at("encoder_input"), matched.at("xpu_encoder"));
    DirectedLink(matched.at("mask"), xpu_mask_adaptive_node);
    DirectedLink(xpu_mask_adaptive_node, xpu_mask_adaptive_seq_lod_node);
    DirectedLink(xpu_mask_adaptive_node, xpu_mask_adaptive_pad_seq_len_node);
    DirectedLink(xpu_mask_adaptive_node, xpu_mask_adaptive_seq_len_node);
    DirectedLink(xpu_mask_adaptive_seq_lod_node, matched.at("xpu_encoder"));
    DirectedLink(xpu_mask_adaptive_pad_seq_len_node, matched.at("xpu_encoder"));
    DirectedLink(xpu_mask_adaptive_seq_len_node, sequence_unpad_node);
    DirectedLink(matched.at("encoder_input"), sequence_unpad_node);
    DirectedLink(sequence_unpad_node, packed_encoder_input_node);
    DirectedLink(packed_encoder_input_node, matched.at("xpu_encoder"));
  }

 private:
  // Matmul flavor this fuser instance matches ("matmul" or "matmul_v2").
  std::string matmul_type_;
};
| 201 | + |
| 202 | +} // namespace fusion |
| 203 | + |
| 204 | +class XPUMultiEncoderAdaptiveSeqlenV3FusePass : public ProgramPass { |
| 205 | + public: |
| 206 | + void Apply(const std::unique_ptr<SSAGraph>& graph) override { |
| 207 | + std::vector<std::string> matmul_types{"matmul", "matmul_v2"}; |
| 208 | + for (auto& matmul_type : matmul_types) { |
| 209 | + fusion::XPUMultiEncoderAdaptiveSeqlenV3Fuser fuser(matmul_type); |
| 210 | + fuser(graph.get()); |
| 211 | + } |
| 212 | + } |
| 213 | +}; |
| 214 | + |
| 215 | +} // namespace mir |
| 216 | +} // namespace lite |
| 217 | +} // namespace paddle |
| 218 | + |
// Register the pass under its lookup name; it is only bound to (and thus only
// runs for) the XPU target.
REGISTER_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v3_fuse_pass,
                  paddle::lite::mir::XPUMultiEncoderAdaptiveSeqlenV3FusePass)
    .BindTargets({TARGET(kXPU)});
0 commit comments