Skip to content

Commit 91b0fd3

Browse files
authored
[XPU] Add select_input_compute and fix bug in box_coder. (#9711)
1 parent eff8dc1 commit 91b0fd3

7 files changed

Lines changed: 119 additions & 4 deletions

File tree

lite/kernels/xpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ add_kernel(shape_compute_xpu XPU extra SRCS shape_compute.cc)
104104
add_kernel(lod_array_length_compute_xpu XPU extra SRCS lod_array_length_compute.cc)
105105
add_kernel(multiclass_nms_compute_xpu XPU extra SRCS multiclass_nms_compute.cc)
106106
add_kernel(lod_reset_compute_xpu XPU extra SRCS lod_reset_compute.cc)
107+
add_kernel(select_input_compute_xpu XPU extra SRCS select_input_compute.cc)
107108

108109
# extra(fused kernel)
109110
add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc)

lite/kernels/xpu/__xpu__multi_encoder_compute.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,11 @@ void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) {
306306
std::vector<int64_t> mask_shape = param.mask->dims().Vectorize();
307307
std::vector<int> encoder_mask_shape =
308308
std::vector<int>(mask_shape.begin(), mask_shape.end());
309-
CHECK_EQ(param.ffn_hidden_dim_scale, 4)
310-
<< "xpu don't support ffn_hidden_dim_scale!=4 when no vsl";
309+
// xpu1 don't support ffn_hidden_dim_scale!=4 when no vsl
310+
if (ctx.GetRawContext()->dev().type() == xdnn::kXPU1) {
311+
CHECK_EQ(param.ffn_hidden_dim_scale, 4)
312+
<< "xpu don't support ffn_hidden_dim_scale!=4 when no vsl";
313+
}
311314
xdnn::QKVAttnParam qkv_attn_param(batch,
312315
max_seqlen,
313316
param.head_num,
@@ -326,6 +329,7 @@ void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) {
326329
qkv_attn_param.relative_pos.assign(roformer_embedding_.begin(),
327330
roformer_embedding_.end());
328331
}
332+
qkv_attn_param.scale_of_hidden_units = param.ffn_hidden_dim_scale;
329333
int r = xdnn::transformer_encoder<T, TW, TGEMM>(
330334
ctx.GetRawContext(),
331335
in,

lite/kernels/xpu/box_coder_compute.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ void BoxCoderCompute::Run() {
6666
output_box->Resize({row, col, len});
6767
auto* output = output_box->mutable_data<float>(TARGET(kXPU));
6868
float* variance_xpu_ptr =
69-
reinterpret_cast<float*>(variance_xpu_guard_->addr_);
69+
variance_xpu_guard_ ? reinterpret_cast<float*>(variance_xpu_guard_->addr_)
70+
: nullptr;
7071

7172
if (code_type == "encode_center_size") {
7273
int r = xdnn::box_coder_encoder<float>(ctx.GetRawContext(),

lite/kernels/xpu/elementwise_compute.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,11 @@ using SubFloat32 =
183183
using SubFloat16 = xpu::ElementwiseCompute<float16,
184184
xpu::SubFunctor<float16>,
185185
PRECISION(kFP16)>;
186-
186+
using SubInt32 =
187+
xpu::ElementwiseCompute<int, xpu::SubFunctor<int>, PRECISION(kFloat)>;
188+
using SubInt64 = xpu::ElementwiseCompute<int64_t,
189+
xpu::SubFunctor<int64_t>,
190+
PRECISION(kFloat)>;
187191
using MulFloat32 =
188192
xpu::ElementwiseCompute<float, xpu::MulFunctor<float>, PRECISION(kFloat)>;
189193
using MulFloat16 = xpu::ElementwiseCompute<float16,
@@ -273,6 +277,18 @@ REGISTER_LITE_KERNEL(
273277
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
274278
.Finalize();
275279

280+
REGISTER_LITE_KERNEL(elementwise_sub, kXPU, kFloat, kNCHW, SubInt32, int32)
281+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
282+
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
283+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
284+
.Finalize();
285+
286+
REGISTER_LITE_KERNEL(elementwise_sub, kXPU, kFloat, kNCHW, SubInt64, int64)
287+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
288+
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
289+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
290+
.Finalize();
291+
276292
REGISTER_LITE_KERNEL(elementwise_mul, kXPU, kFloat, kNCHW, MulFloat32, def)
277293
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
278294
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
lite/kernels/xpu/select_input_compute.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/xpu/select_input_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

// Copies the tensor X[*Mask] into Out as a raw byte blob on the XPU device.
// Mask lives on the host (see the kernel registration below), so its value
// can be dereferenced directly; X and Out live on the XPU.
void SelectInputCompute::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<XPUContext>();
  // Borrow the input tensor list instead of copying it.
  const auto& x = param.X;

  auto output = param.Out;
  // Validate the selector before indexing: an out-of-range or negative
  // index would be undefined behavior on the vector access.
  const int mask_idx = *param.Mask->data<int>();
  CHECK_GE(mask_idx, 0) << "select_input Mask must be non-negative, got "
                        << mask_idx;
  CHECK_LT(mask_idx, static_cast<int>(x.size()))
      << "select_input Mask " << mask_idx << " out of range, input size "
      << x.size();
  auto x_i = x[mask_idx];
  // Size the output buffer to match the selected input, then copy bytes
  // device-to-device; the element type is irrelevant, hence int8_t.
  output->mutable_data(TARGET(kXPU), x_i->memory_size());
  int r = xdnn::copy<int8_t>(ctx.GetRawContext(),
                             x_i->data<int8_t>(),
                             reinterpret_cast<int8_t*>(output->raw_data()),
                             x_i->memory_size());
  CHECK_EQ(r, 0);
}

}  // namespace xpu
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Precision is kAny: the kernel is a byte copy and works for any dtype.
// Mask is bound to the host target so *Mask->data<int>() is a plain load.
REGISTER_LITE_KERNEL(select_input,
                     kXPU,
                     kAny,
                     kNCHW,
                     paddle::lite::kernels::xpu::SelectInputCompute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
    .BindInput("Mask",
               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
    .Finalize();
lite/kernels/xpu/select_input_compute.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <string>
#include "lite/backends/xpu/target_wrapper.h"
#include "lite/core/kernel.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

// XPU kernel for the select_input op: forwards the input tensor chosen by
// the Mask index into the output. Registered with PRECISION(kAny) because
// the implementation is dtype-agnostic (raw byte copy in Run()).
class SelectInputCompute : public KernelLite<TARGET(kXPU), PRECISION(kAny)> {
 public:
  // Parameter struct shared with the select_input operator definition.
  using param_t = operators::SelectInputParam;

  // Performs the selection + device copy; defined in select_input_compute.cc.
  void Run() override;

  virtual ~SelectInputCompute() = default;
};

}  // namespace xpu
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

lite/operators/select_input_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ bool SelectInputOpLite::InferShapeImpl() const {
3737
const auto &output_dims = inputs[Mask]->dims();
3838
// Set output dims
3939
param_.Out->Resize(output_dims);
40+
param_.Out->set_lod(inputs[Mask]->lod());
4041
return true;
4142
}
4243

0 commit comments

Comments
 (0)