Skip to content

Commit 663ed16

Browse files
authored
[ARM] Add silu op (#9280)
* add silu op and python unit test
1 parent 2ff6f96 commit 663ed16

File tree

12 files changed

+239
-1
lines changed

12 files changed

+239
-1
lines changed

lite/api/paddle_place.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ const std::string& ActivationTypeToStr(ActivationType act) {
5252
"PRelu",
5353
"LeakyRelu",
5454
"Sigmoid",
55+
"Silu",
5556
"Tanh",
5657
"Swish",
5758
"Exp",

lite/api/paddle_place.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ enum class ActivationType : int {
143143
kSign = 20,
144144
kSoftPlus = 21,
145145
kMish = 22,
146-
NUM = 23,
146+
kSilu = 23,
147+
NUM = 24,
147148
};
148149

149150
static size_t PrecisionTypeLength(PrecisionType type) {

lite/backends/arm/math/activation.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,50 @@ void mish(const float* din, float* dout, int size, float threshold) {
11181118
dout[i] = x * std::tanh(sp);
11191119
}
11201120
}
1121+
1122+
template <>
void act_silu<float>(const float* din, float* dout, int size, int threads) {
  // silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), elementwise.
  // The input is split evenly across `threads`; each worker processes
  // `nums_per_thread` elements (vectorized 4 at a time with NEON), and the
  // `remain` elements left over from the integer split are handled serially.
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt_dim4 = nums_per_thread >> 2;
  int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);

  LITE_PARALLEL_BEGIN(i, tid, threads) {
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    // Loop-invariant constant hoisted out of the vector loop.
    const float32x4_t vone = vdupq_n_f32(1.0f);
    for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
      float32x4_t x_vec = vld1q_f32(ptr_in_thread);
      // denom = 1 + exp(-x)
      float32x4_t denom = vaddq_f32(exp_ps(vnegq_f32(x_vec)), vone);
      // Reciprocal estimate refined by two Newton-Raphson steps
      // (vrecpsq_f32 computes 2 - d * r, the NR correction factor).
      float32x4_t recip = vrecpeq_f32(denom);
      recip = vmulq_f32(vrecpsq_f32(denom, recip), recip);
      recip = vmulq_f32(vrecpsq_f32(denom, recip), recip);
      vst1q_f32(ptr_out_thread, vmulq_f32(x_vec, recip));
      ptr_in_thread += 4;
      ptr_out_thread += 4;
    }
    // Scalar tail of this thread's chunk (exact division).
    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
      ptr_out_thread[0] = ptr_in_thread[0] / (1.f + expf(-ptr_in_thread[0]));
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  LITE_PARALLEL_END();
  // Elements not covered by the per-thread split (size % threads).
  float* ptr_out = dout + threads * nums_per_thread;
  const float* ptr_in = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    ptr_out[0] = ptr_in[0] / (1.f + expf(-ptr_in[0]));
    ptr_in++;
    ptr_out++;
  }
}
1164+
11211165
} // namespace math
11221166
} // namespace arm
11231167
} // namespace lite

lite/backends/arm/math/activation.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ void softplus(const T* din, T* dout, int size, float beta, int threads);
111111
template <typename T>
112112
void mish(const T* din, T* dout, int size, float threshold);
113113

114+
// SiLU activation: dout[i] = din[i] / (1 + exp(-din[i])), computed on
// `size` elements with work distributed across `threads`.
template <typename T>
void act_silu(const T* din, T* dout, int size, int threads);
116+
114117
} // namespace math
115118
} // namespace arm
116119
} // namespace lite

lite/kernels/arm/activation_compute.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,16 @@ void EluCompute::Run() {
169169
x_data, output_data, x_dims.production(), alpha, ctx.threads());
170170
}
171171

172+
void SiluCompute::Run() {
173+
auto& param = this->Param<param_t>();
174+
auto& ctx = this->ctx_->template As<ARMContext>();
175+
auto x_dims = param.X->dims();
176+
auto x_data = param.X->data<float>();
177+
auto output_data = param.Out->mutable_data<float>();
178+
lite::arm::math::act_silu<float>(
179+
x_data, output_data, x_dims.production(), ctx.threads());
180+
}
181+
172182
} // namespace arm
173183
} // namespace kernels
174184
} // namespace lite
@@ -276,3 +286,8 @@ REGISTER_LITE_KERNEL(
276286
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
277287
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
278288
.Finalize();
289+
// Register the "silu" op with the float/NCHW ARM kernel.
REGISTER_LITE_KERNEL(
    silu, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SiluCompute, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

lite/kernels/arm/activation_compute.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ class EluCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
9898
virtual ~EluCompute() = default;
9999
};
100100

101+
// SiLU activation kernel for ARM (float precision): Out = X * sigmoid(X).
class SiluCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;

  void Run() override;

  virtual ~SiluCompute() = default;
};
109+
101110
} // namespace arm
102111
} // namespace kernels
103112
} // namespace lite

lite/kernels/host/activation_compute.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,17 @@ void SoftplusCompute::Run() {
289289
}
290290
}
291291

292+
void SiluCompute::Run() {
293+
auto& param = this->Param<param_t>();
294+
CHECK(param.X);
295+
auto x_dims = param.X->dims();
296+
auto x_data = param.X->data<float>();
297+
auto output_data = param.Out->mutable_data<float>();
298+
for (int i = 0; i < x_dims.production(); i++) {
299+
output_data[i] = x_data[i] / (1 + std::exp(-x_data[i]));
300+
}
301+
}
302+
292303
} // namespace host
293304
} // namespace kernels
294305
} // namespace lite
@@ -435,3 +446,8 @@ REGISTER_LITE_KERNEL(softplus,
435446
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
436447
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
437448
.Finalize();
449+
// Register the "silu" op with the float/NCHW host (CPU reference) kernel.
REGISTER_LITE_KERNEL(
    silu, kHost, kFloat, kNCHW, paddle::lite::kernels::host::SiluCompute, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
    .Finalize();

lite/kernels/host/activation_compute.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,15 @@ class SoftplusCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
203203
virtual ~SoftplusCompute() = default;
204204
};
205205

206+
// SiLU activation kernel for the host target (float precision):
// Out = X * sigmoid(X). Serves as the portable fallback implementation.
class SiluCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;

  void Run() override;

  virtual ~SiluCompute() = default;
};
214+
206215
} // namespace host
207216
} // namespace kernels
208217
} // namespace lite

lite/operators/activation_ops.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
5959
param_.hard_sigmoid_offset = opdesc.GetAttr<float>("offset");
6060
} else if (opdesc.Type() == "sigmoid") {
6161
param_.active_type = lite_api::ActivationType::kSigmoid;
62+
} else if (opdesc.Type() == "silu") {
63+
param_.active_type = lite_api::ActivationType::kSilu;
6264
} else if (opdesc.Type() == "tanh") {
6365
param_.active_type = lite_api::ActivationType::kTanh;
6466
} else if (opdesc.Type() == "exp") {
@@ -140,3 +142,4 @@ REGISTER_LITE_OP(thresholded_relu, paddle::lite::operators::ActivationOp);
140142
REGISTER_LITE_OP(elu, paddle::lite::operators::ActivationOp);
141143
REGISTER_LITE_OP(erf, paddle::lite::operators::ActivationOp);
142144
REGISTER_LITE_OP(softplus, paddle::lite::operators::ActivationOp);
145+
REGISTER_LITE_OP(silu, paddle::lite::operators::ActivationOp);

lite/operators/activation_ops.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ class ActivationOp : public OpLite {
109109
case lite_api::ActivationType::kSoftPlus:
110110
ch->macs = param_.X->numel();
111111
break;
112+
case lite_api::ActivationType::kSilu:
113+
ch->macs = param_.X->numel();
114+
break;
112115
default:
113116
LOG(FATAL) << "This Type of Activation:"
114117
<< static_cast<int>(param_.active_type)

0 commit comments

Comments
 (0)