
Commit 4bfa2d0

[XPU] add seq_softmax, seq_expand, lod_reset op in xpu (#9453)
1 parent 6b04c22 commit 4bfa2d0

16 files changed: 471 additions & 35 deletions (7 of the 16 files are shown below)

lite/kernels/xpu/CMakeLists.txt

Lines changed: 3 additions & 0 deletions

@@ -67,6 +67,8 @@ add_kernel(sequence_reverse_compute_xpu XPU extra SRCS sequence_reverse_compute.
 add_kernel(sequence_concat_compute_xpu XPU extra SRCS sequence_concat_compute.cc)
 add_kernel(sequence_arithmetic_compute_xpu XPU extra SRCS sequence_arithmetic_compute.cc)
 add_kernel(sequence_pool_compute_xpu XPU extra SRCS sequence_pool_compute.cc)
+add_kernel(sequence_expand_compute_xpu XPU extra SRCS sequence_expand_compute.cc)
+add_kernel(sequence_softmax_compute_xpu XPU extra SRCS sequence_softmax_compute.cc)
 add_kernel(match_matrix_tensor_compute_xpu XPU extra SRCS match_matrix_tensor_compute.cc)
 add_kernel(var_conv_2d_compute_xpu XPU extra SRCS var_conv_2d_compute.cc)
 add_kernel(search_grnn_compute_xpu XPU extra SRCS search_grnn_compute.cc)
@@ -101,6 +103,7 @@ add_kernel(is_empty_compute_xpu XPU extra SRCS is_empty_compute.cc)
 add_kernel(shape_compute_xpu XPU extra SRCS shape_compute.cc)
 add_kernel(lod_array_length_compute_xpu XPU extra SRCS lod_array_length_compute.cc)
 add_kernel(multiclass_nms_compute_xpu XPU extra SRCS multiclass_nms_compute.cc)
+add_kernel(lod_reset_compute_xpu XPU extra SRCS lod_reset_compute.cc)
 
 # extra(fused kernel)
 add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc)

lite/kernels/xpu/gru_compute.cc

Lines changed: 8 additions & 5 deletions

@@ -56,15 +56,18 @@ void GRUCompute::PrepareForRun() {
       paddle::lite::xpu::math::FindMaxAbs(weight_s1_ptr, weight_s1_len);
   weight_s2_abs_max_ =
       paddle::lite::xpu::math::FindMaxAbs(weight_s2_ptr, weight_s2_len);
-  std::vector<float> weight_max_vector(8);
-  for (int i = 0; i < 4; i++) {
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  int max_ptr_size = ctx.GetRawContext()->max_ptr_size();
+  std::vector<float> weight_max_vector(max_ptr_size * 2);
+  for (int i = 0; i < max_ptr_size; i++) {
     weight_max_vector[i] = weight_s1_abs_max_;
-    weight_max_vector[i + 4] = weight_s2_abs_max_;
+    weight_max_vector[i + max_ptr_size] = weight_s2_abs_max_;
   }
-  weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float));
+  weight_max_guard_ =
+      TargetWrapperXPU::MallocScratchPad(max_ptr_size * 2 * sizeof(float));
   XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_guard_->addr_),
                       weight_max_vector.data(),
-                      8 * sizeof(float),
+                      max_ptr_size * 2 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   // quant
   quant_weight_guard_ =
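
Note: both GRU kernels make the same PrepareForRun() change: the max-value scratch buffer is no longer hard-coded to 8 floats (4 per weight) but is sized from the runtime-queried max_ptr_size(). A standalone sketch of the host-side buffer layout; build_weight_max is a hypothetical helper for illustration, not part of the commit:

#include <vector>

// Replicates the two per-weight abs-max values max_ptr_size times each,
// matching the layout the kernel copies to device scratch memory via
// xpu_memcpy. max_ptr_size is whatever the XPU runtime reports; the old
// code assumed it was always 4.
std::vector<float> build_weight_max(float s1_abs_max, float s2_abs_max,
                                    int max_ptr_size) {
  std::vector<float> v(max_ptr_size * 2);
  for (int i = 0; i < max_ptr_size; ++i) {
    v[i] = s1_abs_max;                 // first half: weight_s1 max
    v[i + max_ptr_size] = s2_abs_max;  // second half: weight_s2 max
  }
  return v;
}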

lite/kernels/xpu/gru_unit_compute.cc

Lines changed: 14 additions & 8 deletions

@@ -42,15 +42,19 @@ void GRUUnitCompute::PrepareForRun() {
       paddle::lite::xpu::math::FindMaxAbs(weight_s1_ptr, weight_s1_len);
   weight_s2_abs_max_ =
       paddle::lite::xpu::math::FindMaxAbs(weight_s2_ptr, weight_s2_len);
-  std::vector<float> weight_max_vector(8);
-  for (int i = 0; i < 4; i++) {
+
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  int max_ptr_size = ctx.GetRawContext()->max_ptr_size();
+  std::vector<float> weight_max_vector(max_ptr_size * 2);
+  for (int i = 0; i < max_ptr_size; i++) {
     weight_max_vector[i] = weight_s1_abs_max_;
-    weight_max_vector[i + 4] = weight_s2_abs_max_;
+    weight_max_vector[i + max_ptr_size] = weight_s2_abs_max_;
   }
-  weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float));
+  weight_max_guard_ =
+      TargetWrapperXPU::MallocScratchPad(max_ptr_size * 2 * sizeof(float));
   XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_guard_->addr_),
                       weight_max_vector.data(),
-                      8 * sizeof(float),
+                      max_ptr_size * 2 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   // quant
   quant_weight_guard_ =
@@ -103,14 +107,14 @@ void GRUUnitCompute::Run() {
   const float* bias_ptr = (bias == nullptr) ? nullptr : bias->data<float>();
 
   float* hidden_ptr = hidden->mutable_data<float>(TARGET(kXPU));
-
-  int ret = xdnn::gru_unit<float, int16_t, float, int16_t>(
+  int ret = xdnn::gru_core<float, int16_t, float, int16_t>(
       ctx.GetRawContext(),
       input_ptr,
       hidden_prev_ptr,
       weight_ptr,
       hidden_ptr,
       batch_size,
+      1,
       frame_size,
       nullptr,
       nullptr,
@@ -119,7 +123,9 @@
       bias_ptr,
       xdnn::Activation_t::TANH,
       xdnn::Activation_t::SIGMOID,
-      origin_mode);
+      origin_mode,
+      false,
+      false);
   CHECK_EQ(ret, 0) << "call xdnn::gru_unit failed!";
 }
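
Note: Run() now calls xdnn::gru_core with a sequence length of 1 in place of the removed xdnn::gru_unit, so a single GRU step is computed; the trailing false, false arguments are additional flags of the newer API. For orientation, a scalar CPU sketch of one GRU step; gru_cell_step is an illustrative helper with precomputed gate projections, not the xdnn kernel, and the origin_mode mapping shown is one common convention rather than a statement of xdnn internals:

#include <cmath>
#include <cstdio>

// One GRU step for frame_size == 1. xu/xr/xc are the input projections of
// the update, reset, and candidate gates; hu/hr/hc the hidden projections.
float gru_cell_step(float h_prev,
                    float xu, float xr, float xc,
                    float hu, float hr, float hc,
                    bool origin_mode) {
  auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };
  float u = sigmoid(xu + hu);        // update gate (SIGMOID above)
  float r = sigmoid(xr + hr);        // reset gate  (SIGMOID above)
  float c = std::tanh(xc + r * hc);  // candidate   (TANH above)
  // origin_mode toggles between the two common conventions for mixing
  // the previous state with the candidate.
  return origin_mode ? u * h_prev + (1.0f - u) * c
                     : (1.0f - u) * h_prev + u * c;
}

int main() {
  std::printf("h = %f\n",
              gru_cell_step(0.5f, 0.1f, -0.2f, 0.3f, 0.05f, 0.1f, -0.1f,
                            /*origin_mode=*/false));
  return 0;
}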

lite/kernels/xpu/layer_norm_compute.cc

Lines changed: 28 additions & 17 deletions

@@ -21,7 +21,8 @@ namespace lite {
 namespace kernels {
 namespace xpu {
 
-void LayerNormCompute::Run() {
+template <typename InType, PrecisionType PType>
+void LayerNormCompute<InType, PType>::Run() {
   auto& param = this->template Param<param_t>();
   auto& ctx = this->ctx_->template As<XPUContext>();
 
@@ -30,16 +31,17 @@ void LayerNormCompute::Run() {
   auto matrix_dim = x_dims.Flatten2D(axis);
   float epsilon = param.epsilon;
 
-  int r = xdnn::layer_norm(ctx.GetRawContext(),                        /* context */
-                           param.X->data<float>(),                     /* in */
-                           param.Y->mutable_data<float>(TARGET(kXPU)), /* out */
-                           matrix_dim[0],                              /* m */
-                           matrix_dim[1],                              /* n */
-                           epsilon,                                    /* epsilon */
-                           param.Scale->data<float>(),                 /* scale */
-                           param.Bias->data<float>(),                  /* bias */
-                           nullptr,
-                           nullptr);
+  int r = xdnn::layer_norm<InType>(
+      ctx.GetRawContext(),                                   /* context */
+      param.X->template data<InType>(),                      /* in */
+      param.Y->template mutable_data<InType>(TARGET(kXPU)),  /* out */
+      matrix_dim[0],                                         /* m */
+      matrix_dim[1],                                         /* n */
+      epsilon,                                               /* epsilon */
+      param.Scale->template data<float>(),                   /* scale */
+      param.Bias->template data<float>(),                    /* bias */
+      nullptr,
+      nullptr);
 
   CHECK_EQ(r, 0);
 }
@@ -49,16 +51,25 @@ void LayerNormCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(layer_norm,
-                     kXPU,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::xpu::LayerNormCompute,
-                     def)
+namespace xpu = paddle::lite::kernels::xpu;
+
+using LayerNorm_FP32 = xpu::LayerNormCompute<float, PRECISION(kFloat)>;
+using LayerNorm_FP16 = xpu::LayerNormCompute<float16, PRECISION(kFP16)>;
+REGISTER_LITE_KERNEL(layer_norm, kXPU, kFloat, kNCHW, LayerNorm_FP32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Variance", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(layer_norm, kXPU, kFP16, kNCHW, LayerNorm_FP16, fp16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Variance",
+                {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
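
Note: the templating above lets one kernel body serve both the FP32 and FP16 registrations; the m, n, epsilon, scale, and bias arguments map onto standard layer normalization over an m x n matrix. A plain CPU reference of that computation; layer_norm_ref is an illustrative sketch, not the xdnn implementation:

#include <cmath>
#include <cstdio>
#include <vector>

// Normalizes each of the m rows of x (length n) to zero mean and unit
// variance, then applies the per-column scale and bias, mirroring the
// (m, n, epsilon, scale, bias) arguments passed to xdnn::layer_norm above.
void layer_norm_ref(const float* x, float* y, int m, int n, float epsilon,
                    const float* scale, const float* bias) {
  for (int i = 0; i < m; ++i) {
    const float* row = x + i * n;
    float mean = 0.0f, var = 0.0f;
    for (int j = 0; j < n; ++j) mean += row[j];
    mean /= n;
    for (int j = 0; j < n; ++j) var += (row[j] - mean) * (row[j] - mean);
    var /= n;
    float inv_std = 1.0f / std::sqrt(var + epsilon);
    for (int j = 0; j < n; ++j)
      y[i * n + j] = (row[j] - mean) * inv_std * scale[j] + bias[j];
  }
}

int main() {
  std::vector<float> x = {1, 2, 3, 4}, y(4), scale = {1, 1}, bias = {0, 0};
  layer_norm_ref(x.data(), y.data(), /*m=*/2, /*n=*/2, 1e-5f,
                 scale.data(), bias.data());
  for (float v : y) std::printf("%f ", v);
  return 0;
}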

lite/kernels/xpu/layer_norm_compute.h

Lines changed: 2 additions & 1 deletion

@@ -21,7 +21,8 @@ namespace lite {
 namespace kernels {
 namespace xpu {
 
-class LayerNormCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+template <typename InType, PrecisionType PType>
+class LayerNormCompute : public KernelLite<TARGET(kXPU), PType> {
  public:
   using param_t = operators::LayerNormParam;
 
lite/kernels/xpu/lod_reset_compute.cc (new file)

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/xpu/lod_reset_compute.h"
+#include <algorithm>
+#include <vector>
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+void LodResetCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+
+  auto x = param.X;
+  auto output = param.Out;
+  output->mutable_data(TARGET(kXPU), x->memory_size());
+  int r = xdnn::copy<int8_t>(ctx.GetRawContext(),
+                             x->data<int8_t>(),
+                             reinterpret_cast<int8_t*>(output->raw_data()),
+                             x->memory_size());
+  CHECK_EQ(r, 0);
+  auto lod = output->mutable_lod();
+  if (param.Y) {
+    if (param.Y->lod().size()) {
+      *lod = param.Y->lod();
+    } else {
+      const auto* y_data = param.Y->data<int>();
+      std::vector<int> y_cpu(param.Y->numel());
+      TargetWrapperXPU::MemcpySync(y_cpu.data(),
+                                   y_data,
+                                   param.Y->numel() * sizeof(int),
+                                   IoDirection::DtoH);
+      (*lod).resize(1);
+      (*lod)[0].resize(param.Y->numel());
+      for (int i = 0; i < param.Y->numel(); i++) {
+        (*lod)[0][i] = y_cpu[i];
+      }
+    }
+  } else {
+    (*lod).resize(1);
+    for (auto id : param.target_lod) {
+      (*lod)[0].push_back(id);
+    }
+  }
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(lod_reset,
+                     kXPU,
+                     kAny,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::LodResetCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
+    .Finalize();
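
Note: the kernel copies X's bytes into Out unchanged and only rewrites the LoD (level-of-detail sequence offsets, e.g. {0, 2, 5} means two sequences covering rows [0, 2) and [2, 5)), taking them from Y's LoD if present, else from Y's data, else from the target_lod attribute. A host-side sketch of just that offset logic; lod_reset_ref is an illustrative helper in which an empty vector stands in for an absent Y, not the kernel itself:

#include <cstdint>
#include <vector>

using Lod = std::vector<std::vector<uint64_t>>;

// Mirrors the branch order of LodResetCompute::Run(): prefer Y's own LoD,
// then Y's data (which the kernel copies to the host first), then the
// target_lod attribute. Tensor contents are untouched either way.
Lod lod_reset_ref(const Lod& y_lod,
                  const std::vector<int>& y_data,
                  const std::vector<int>& target_lod) {
  if (!y_lod.empty()) return y_lod;
  Lod lod(1);
  const std::vector<int>& src = !y_data.empty() ? y_data : target_lod;
  for (int v : src) lod[0].push_back(static_cast<uint64_t>(v));
  return lod;
}

int main() {
  Lod lod = lod_reset_ref({}, {}, {0, 2, 5});
  // lod[0] == {0, 2, 5}: two sequences covering five rows.
  return lod[0].size() == 3 ? 0 : 1;
}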
lite/kernels/xpu/lod_reset_compute.h (new file)

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+class LodResetCompute : public KernelLite<TARGET(kXPU), PRECISION(kAny)> {
+ public:
+  using param_t = operators::LodResetParam;
+
+  void Run() override;
+
+  virtual ~LodResetCompute() = default;
+};
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
