Skip to content

Commit 8605f6b

Browse files
committed
[XPU][Cherry-pick] support xpu mask adaptive output without embedding (PaddlePaddle#9970)
1 parent c0af63c commit 8605f6b

11 files changed

Lines changed: 458 additions & 4 deletions

lite/api/paddle_use_passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ USE_MIR_PASS(__xpu__conv2d_affine_channel_fuse_pass);
8787
USE_MIR_PASS(__xpu__conv2d_fuse_pass);
8888
USE_MIR_PASS(__xpu__softmax_topk_fuse_pass);
8989
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_fuse_pass);
90+
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v3_fuse_pass);
9091
USE_MIR_PASS(__xpu__roformer_relative_pos_fuse_pass);
9192
USE_MIR_PASS(__xpu__multi_encoder_slice_link_fuse_pass);
9293
USE_MIR_PASS(__xpu__generate_sequence_fuse_pass);
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <memory>
16+
#include <string>
17+
#include "lite/backends/xpu/math.h"
18+
#include "lite/core/optimizer/mir/pass_registry.h"
19+
#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"
20+
21+
namespace paddle {
22+
namespace lite {
23+
namespace mir {
24+
namespace fusion {
25+
26+
/* support adaptive seq len for mrc */
27+
/* in_Input in_Mask */
28+
/* | | */
29+
/* | | */
30+
/* | matmul */
31+
/* | | */
32+
/* | | */
33+
/* | scale */
34+
/* | / */
35+
/* | stack */
36+
/* | | */
37+
/* | / */
38+
/* | / */
39+
/* xpu_encoder */
40+
/* | */
41+
/* | */
42+
/* out_Output */
43+
/*-------------------------------------------*/
44+
/* After the pass apply: */
45+
/* in_Input in_Mask */
46+
/* | | */
47+
/* | | */
48+
/* | xpu_adaptive_mask */
49+
/* | | | */
50+
/*      sequence_unpad<--Length              |         */
51+
/* | | */
52+
/* | PadSeqLen */
53+
/* | SeqLod */
54+
/* | / */
55+
/* | / */
56+
/* | / */
57+
/* xpu_encoder */
58+
/* | */
59+
/* | */
60+
/* out_Output */
61+
/*-------------------------------------------*/
62+
63+
// Fuser that detects the "mask -> matmul(mask, mask) -> scale((x-1)*10000)
// -> stack -> __xpu__multi_encoder" attention-mask subgraph and replaces it
// with a single __xpu__mask_adaptive op plus a sequence_unpad op, so the
// encoder runs on variable-sequence-length (packed) input without padding.
class XPUMultiEncoderAdaptiveSeqlenV3Fuser : public FuseBase {
 public:
  // matmul_type: flavor of the matmul used to square the mask in the graph
  // being matched ("matmul" or "matmul_v2").
  explicit XPUMultiEncoderAdaptiveSeqlenV3Fuser(
      const std::string& matmul_type = "matmul")
      : matmul_type_(matmul_type) {}

  // Declares the subgraph pattern to match. All intermediate nodes
  // (matmul, scale, stack and their outputs) are marked AsIntermediate so
  // they are deleted when the fusion is applied; "mask", "encoder_input"
  // and "xpu_encoder" survive.
  void BuildPattern() override {
    // The mask is multiplied with itself: it feeds both matmul inputs.
    auto* mask = VarNode("mask")
                     ->assert_is_op_input(matmul_type_, "X")
                     ->assert_is_op_input(matmul_type_, "Y");
    auto* matmul = OpNode(matmul_type_, matmul_type_)->AsIntermediate();
    auto* matmul_out = VarNode("matmul_out")
                           ->assert_is_op_input("scale", "X")
                           ->assert_is_op_output(matmul_type_, "Out")
                           ->AsIntermediate();
    // Only match the canonical mask transform (x - 1) * 10000, i.e.
    // bias == -1, scale == 10000, bias applied before scaling.
    auto* scale =
        OpNode("scale", "scale")
            ->assert_op_attr<bool>("bias_after_scale", false)
            ->assert_op_attr_satisfied<float>(
                "bias",
                [](float attr) { return (std::fabs(attr + 1.0) < 1e-5); })
            ->assert_op_attr_satisfied<float>(
                "scale",
                [](float attr) { return (std::fabs(attr - 10000.0) < 1e-5); })
            ->AsIntermediate();
    auto* scale_out = VarNode("scale_out")
                          ->assert_is_op_input("stack", "X")
                          ->assert_is_op_output("scale", "Out")
                          ->AsIntermediate();
    auto* stack = OpNode("stack", "stack")->AsIntermediate();
    auto* stack_out = VarNode("stack_out")
                          ->assert_is_op_input("__xpu__multi_encoder", "Mask")
                          ->assert_is_op_output("stack", "Y")
                          ->AsIntermediate();
    auto* encoder_input =
        VarNode("encoder_input")
            ->assert_is_op_input("__xpu__multi_encoder", "Input");
    // Only encoders already flagged for adaptive seqlen are eligible.
    auto* xpu_encoder = OpNode("xpu_encoder", "__xpu__multi_encoder")
                            ->assert_op_attr<bool>("adaptive_seqlen", true);

    *mask >> *matmul >> *matmul_out >> *scale >> *scale_out >> *stack >>
        *stack_out >> *xpu_encoder;
    *encoder_input >> *xpu_encoder;
  }

  // Rewrites the matched subgraph: creates the __xpu__mask_adaptive and
  // sequence_unpad instruct nodes, the argument nodes/tensors they produce,
  // re-targets the encoder's inputs, and rewires the graph links.
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
    auto* encoder_instruct = matched.at("xpu_encoder")->stmt();
    auto encoder_op_desc = encoder_instruct->mutable_op_info();
    auto encoder_op = encoder_instruct->op();
    auto* scope = encoder_op->scope();

    // add new arg seq_lod (int32, host): per-batch cumulative lengths
    // consumed by the encoder.
    std::string stack_out_name = matched.at("stack_out")->arg()->name;
    std::string xpu_mask_adaptive_seq_lod_name = stack_out_name + "_seq_lod";
    auto* xpu_mask_adaptive_seq_lod_node =
        graph->NewArgumentNode(xpu_mask_adaptive_seq_lod_name);
    xpu_mask_adaptive_seq_lod_node->arg()->type = LiteType::GetTensorTy(
        TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kNCHW));
    scope->NewTensor(xpu_mask_adaptive_seq_lod_name);
    // add new arg pad_seq_len, store max padded length
    std::string xpu_mask_adaptive_pad_seq_len_name =
        stack_out_name + "_pad_seq_len";
    auto* xpu_mask_adaptive_pad_seq_len_node =
        graph->NewArgumentNode(xpu_mask_adaptive_pad_seq_len_name);
    xpu_mask_adaptive_pad_seq_len_node->arg()->type = LiteType::GetTensorTy(
        TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kNCHW));
    scope->NewTensor(xpu_mask_adaptive_pad_seq_len_name);
    // add new arg length, for sequence_unpad, store length in batch
    // (int64, as sequence_unpad expects).
    std::string xpu_mask_adaptive_seq_len_name = stack_out_name + "_seq_length";
    auto* xpu_mask_adaptive_seq_len_node =
        graph->NewArgumentNode(xpu_mask_adaptive_seq_len_name);
    xpu_mask_adaptive_seq_len_node->arg()->type = LiteType::GetTensorTy(
        TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kNCHW));
    scope->NewTensor(xpu_mask_adaptive_seq_len_name);

    // add new packed input of encoder: the unpadded (VSL) tensor produced
    // by sequence_unpad, resident on XPU.
    std::string orig_encoder_input_name =
        matched.at("encoder_input")->arg()->name;
    std::string packed_encoder_input_name =
        orig_encoder_input_name + "_vsl_packed";
    auto* packed_encoder_input_node =
        graph->NewArgumentNode(packed_encoder_input_name);
    packed_encoder_input_node->arg()->type = LiteType::GetTensorTy(
        TARGET(kXPU), PRECISION(kFloat), DATALAYOUT(kNCHW));
    scope->NewTensor(packed_encoder_input_name);

    // create xpu_mask_adaptive op to set lod
    cpp::OpDesc op_desc;
    op_desc.SetType("__xpu__mask_adaptive");
    op_desc.SetInput("Mask", {matched.at("mask")->arg()->name});
    op_desc.SetOutput(
        "Length",
        {xpu_mask_adaptive_seq_len_name});  // length for sequence_unpad op
    op_desc.SetOutput("SeqLod",
                      {xpu_mask_adaptive_seq_lod_name});  // lod for encoder op
    op_desc.SetOutput("PadSeqLen", {xpu_mask_adaptive_pad_seq_len_name});
    auto xpu_mask_adaptive_op =
        LiteOpRegistry::Global().Create("__xpu__mask_adaptive");
    auto& valid_places = encoder_op->valid_places();
    xpu_mask_adaptive_op->Attach(op_desc, scope);
    auto* xpu_mask_adaptive_node =
        graph->GraphCreateInstructNode(xpu_mask_adaptive_op, valid_places);

    // create sequence_unpad to pack the encoder input
    cpp::OpDesc sequence_unpad_op_desc;
    sequence_unpad_op_desc.SetType("sequence_unpad");
    sequence_unpad_op_desc.SetInput("X",
                                    {matched.at("encoder_input")->arg()->name});
    sequence_unpad_op_desc.SetInput("Length", {xpu_mask_adaptive_seq_len_name});
    sequence_unpad_op_desc.SetOutput("Out", {packed_encoder_input_name});
    auto sequence_unpad_op = LiteOpRegistry::Global().Create("sequence_unpad");
    sequence_unpad_op->Attach(sequence_unpad_op_desc, scope);
    auto* sequence_unpad_node =
        graph->GraphCreateInstructNode(sequence_unpad_op, valid_places);

    // Point the encoder at the packed input and the new seq metadata, then
    // ResetOp with a copy of the updated desc so the change takes effect.
    encoder_op_desc->SetInput("Input", {packed_encoder_input_name});
    encoder_op_desc->SetInput("SeqLod", {xpu_mask_adaptive_seq_lod_name});
    encoder_op_desc->SetInput("PadSeqLen",
                              {xpu_mask_adaptive_pad_seq_len_name});
    auto updated_encoder_op_desc = *encoder_instruct->mutable_op_info();
    encoder_instruct->ResetOp(updated_encoder_op_desc, valid_places);

    // Rewire the graph: the original padded input no longer feeds the
    // encoder directly; it goes through sequence_unpad instead.
    RemoveDirectedLink(matched.at("encoder_input"), matched.at("xpu_encoder"));
    DirectedLink(matched.at("mask"), xpu_mask_adaptive_node);
    DirectedLink(xpu_mask_adaptive_node, xpu_mask_adaptive_seq_lod_node);
    DirectedLink(xpu_mask_adaptive_node, xpu_mask_adaptive_pad_seq_len_node);
    DirectedLink(xpu_mask_adaptive_node, xpu_mask_adaptive_seq_len_node);
    DirectedLink(xpu_mask_adaptive_seq_lod_node, matched.at("xpu_encoder"));
    DirectedLink(xpu_mask_adaptive_pad_seq_len_node, matched.at("xpu_encoder"));
    DirectedLink(xpu_mask_adaptive_seq_len_node, sequence_unpad_node);
    DirectedLink(matched.at("encoder_input"), sequence_unpad_node);
    DirectedLink(sequence_unpad_node, packed_encoder_input_node);
    DirectedLink(packed_encoder_input_node, matched.at("xpu_encoder"));
  }

 private:
  // Matmul op type this fuser instance matches against.
  std::string matmul_type_;
};
201+
202+
} // namespace fusion
203+
204+
class XPUMultiEncoderAdaptiveSeqlenV3FusePass : public ProgramPass {
205+
public:
206+
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
207+
std::vector<std::string> matmul_types{"matmul", "matmul_v2"};
208+
for (auto& matmul_type : matmul_types) {
209+
fusion::XPUMultiEncoderAdaptiveSeqlenV3Fuser fuser(matmul_type);
210+
fuser(graph.get());
211+
}
212+
}
213+
};
214+
215+
} // namespace mir
216+
} // namespace lite
217+
} // namespace paddle
218+
219+
// Register the pass under the name used in optimizer pass lists
// (see USE_MIR_PASS in paddle_use_passes.h); it is bound to XPU targets only.
REGISTER_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v3_fuse_pass,
                  paddle::lite::mir::XPUMultiEncoderAdaptiveSeqlenV3FusePass)
    .BindTargets({TARGET(kXPU)});

lite/core/optimizer/optimizer.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
205205
"__xpu__fc_fuse_pass",
206206
"__xpu__softmax_topk_fuse_pass",
207207
"__xpu__multi_encoder_adaptive_seqlen_fuse_pass",
208+
"__xpu__multi_encoder_adaptive_seqlen_v3_fuse_pass",
208209
"__xpu__multi_encoder_slice_link_fuse_pass",
209210
"__xpu__generate_sequence_fuse_pass",
210211
"__xpu__logit_fuse_pass",

lite/kernels/xpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ add_kernel(deformable_conv_compute_xpu XPU extra SRCS deformable_conv_compute.cc
110110
add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc)
111111
add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc)
112112
add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc)
113+
add_kernel(__xpu__mask_adaptive_compute_xpu XPU extra SRCS __xpu__mask_adaptive_compute.cc)
113114
add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc)
114115
add_kernel(__xpu__search_attention_compute_xpu XPU extra SRCS __xpu__search_attention_compute.cc)
115116
add_kernel(__xpu__search_attention_2_compute_xpu XPU extra SRCS __xpu__search_attention_2_compute.cc)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/xpu/__xpu__mask_adaptive_compute.h"
16+
#include <vector>
17+
#include "lite/core/op_registry.h"
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace kernels {
22+
namespace xpu {
23+
24+
void XPUMaskAdaptiveCompute::Run() {
25+
auto& param = this->template Param<param_t>();
26+
CHECK(param.Mask && param.Mask->data<float>()) << "mask null";
27+
auto& mask_dims = param.Mask->dims();
28+
auto batch_size = mask_dims[0];
29+
auto pad_seq_len = mask_dims[1];
30+
param.PadSeqLen->mutable_data<int>()[0] = pad_seq_len;
31+
auto* seq_lod = param.SeqLod;
32+
seq_lod->Resize({batch_size + 1});
33+
std::vector<int> cpu_seq_lod{0};
34+
auto* seq_len = param.Length;
35+
seq_len->Resize({batch_size});
36+
std::vector<int64_t> cpu_seq_lens;
37+
38+
const float* mask_ptr = param.Mask->data<float>();
39+
40+
for (auto batch_idx = 0; batch_idx < batch_size; batch_idx++) {
41+
int cur_batch_seq_len = 0;
42+
for (auto seq_idx = 0; seq_idx < pad_seq_len; seq_idx++) {
43+
if (mask_ptr[batch_idx * pad_seq_len + seq_idx] > 1e-7) {
44+
cur_batch_seq_len += 1;
45+
} else {
46+
break;
47+
}
48+
}
49+
CHECK_GT(cur_batch_seq_len, 0);
50+
cpu_seq_lod.push_back(cpu_seq_lod.back() + cur_batch_seq_len);
51+
cpu_seq_lens.push_back(cur_batch_seq_len);
52+
}
53+
auto* seq_lod_ptr = seq_lod->mutable_data<int>();
54+
memcpy(seq_lod_ptr, cpu_seq_lod.data(), cpu_seq_lod.size() * sizeof(int));
55+
auto* seq_lens_ptr = seq_len->mutable_data<int64_t>();
56+
memcpy(
57+
seq_lens_ptr, cpu_seq_lens.data(), cpu_seq_lens.size() * sizeof(int64_t));
58+
}
59+
60+
} // namespace xpu
61+
} // namespace kernels
62+
} // namespace lite
63+
} // namespace paddle
64+
65+
// Register the kernel for op "__xpu__mask_adaptive". The kernel is placed
// on kXPU, but all of its bound tensors live on the host: the computation
// itself is a CPU-side scan of the mask.
REGISTER_LITE_KERNEL(__xpu__mask_adaptive,
                     kXPU,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::xpu::XPUMaskAdaptiveCompute,
                     def)
    .BindInput("Mask",
               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
    .BindOutput("SeqLod",
                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
    .BindOutput("PadSeqLen",
                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
    .BindOutput("Length",
                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
    .Finalize();
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#pragma once
15+
16+
#include "lite/core/kernel.h"
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace kernels {
21+
namespace xpu {
22+
23+
class XPUMaskAdaptiveCompute
24+
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
25+
public:
26+
using param_t = operators::XPUMaskAdaptiveParam;
27+
28+
void Run();
29+
virtual ~XPUMaskAdaptiveCompute() = default;
30+
};
31+
32+
} // namespace xpu
33+
} // namespace kernels
34+
} // namespace lite
35+
} // namespace paddle

lite/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc)
234234
add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc)
235235
add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc)
236236
add_operator(__xpu__roformer_relative_embedding_op extra SRCS __xpu__roformer_relative_embedding_op.cc)
237+
add_operator(__xpu__mask_adaptive_op extra SRCS __xpu__mask_adaptive_op.cc)
237238
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc)
238239
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc)
239240
add_operator(__xpu__conv2d_op extra SRCS __xpu__conv2d_op.cc)

0 commit comments

Comments
 (0)