Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions lite/backends/arm/math/fp16/common_preprocess.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ typedef __fp16 float16_t;
const dtype *ptr_w5, const dtype *ptr_w6, const dtype *ptr_w7, \
int remain

#define PTR_ACQUIRE_PARAM_4(dtype) \
const dtype *ptr_zero, const dtype *ptr_w0, const dtype *ptr_w1, \
const dtype *ptr_w2, const dtype *ptr_w3, int remain

#define PTR_ACQUIRE_PARAM_A8(dtype) \
const dtype *zerobuff, const dtype *inptr1, const dtype *inptr2, \
const dtype *inptr3, const dtype *inptr4, const dtype *inptr5, \
Expand Down Expand Up @@ -164,28 +168,28 @@ typedef __fp16 float16_t;
}

inline void act_acquire(lite_api::ActivationType act,
int &flag_act, // NOLINT
float &local_alpha, // NOLINT
float &offset, // NOLINT
float &threshold, // NOLINT
int &flag_act, // NOLINT
float16_t &local_alpha, // NOLINT
float16_t &offset, // NOLINT
float16_t &threshold, // NOLINT
const operators::ActivationParam act_param) {
switch (act) {
case lite_api::ActivationType::kRelu:
flag_act = 0x01;
break;
case lite_api::ActivationType::kRelu6:
flag_act = 0x02;
local_alpha = act_param.Relu_clipped_coef;
local_alpha = static_cast<float16_t>(act_param.Relu_clipped_coef);
break;
case lite_api::ActivationType::kLeakyRelu:
flag_act = 0x03;
local_alpha = act_param.Leaky_relu_alpha;
local_alpha = static_cast<float16_t>(act_param.Leaky_relu_alpha);
break;
case lite_api::ActivationType::kHardSwish:
flag_act = 0x04;
local_alpha = 1.0 / act_param.hard_swish_scale;
offset = act_param.hard_swish_offset;
threshold = act_param.hard_swish_threshold;
local_alpha = static_cast<float16_t>(1.0 / act_param.hard_swish_scale);
offset = static_cast<float16_t>(act_param.hard_swish_offset);
threshold = static_cast<float16_t>(act_param.hard_swish_threshold);
break;
default:
break;
Expand Down Expand Up @@ -221,6 +225,23 @@ inline void ptr_acquire_remain(PTR_ACQUIRE_PARAM(dtype)) {
}
}

template <typename dtype>
inline void ptr_acquire_remain_four(PTR_ACQUIRE_PARAM_4(dtype)) {
switch (4 - remain) {
case 3:
ptr_w0 = ptr_zero;
break;
case 2:
ptr_w1 = ptr_zero;
break;
case 1:
ptr_w2 = ptr_zero;
break;
default:
break;
}
}

template <typename dtype>
inline void ptr_acquire_norm(PTR_ACQUIRE_PARAM(dtype)) {
switch (8 - remain) {
Expand All @@ -244,6 +265,20 @@ inline void ptr_acquire_norm(PTR_ACQUIRE_PARAM(dtype)) {
}
}

template <typename dtype>
inline void ptr_acquire_norm_four(PTR_ACQUIRE_PARAM_4(dtype)) {
switch (4 - remain) {
case 3:
ptr_w1 = ptr_zero;
case 2:
ptr_w2 = ptr_zero;
case 1:
ptr_w3 = ptr_zero;
break;
default:
break;
}
}
template <typename dtype>
inline void ptr_acquire_a8(PTR_ACQUIRE_PARAM_A8(dtype)) {
switch (numa - numb) {
Expand Down
12 changes: 6 additions & 6 deletions lite/backends/arm/math/fp16/conv3x3_winograd_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ void conv_compute_2x2_3x3_fp16(const float16_t* input,
float16_t* g_trans_remain_tmp_data = g_trans_tmp_data + threads * 128;
bool flag_bias = (bias != nullptr);
auto act_type = act_param.active_type;
float local_alpha = 0.f;
float16_t local_alpha = 0.f;
int flag_act = 0x00; // relu: 1, relu6: 2, leakey: 3
float offset = 0.f;
float threshold = 6.f;
float16_t offset = 0.f;
float16_t threshold = 6.f;

if (act_param.has_active) {
act_acquire(act_type, flag_act, local_alpha, offset, threshold, act_param);
Expand Down Expand Up @@ -377,11 +377,11 @@ void conv_compute_4x4_3x3_fp16(const float16_t* input,
float16_t* g_trans_tmp_data = g_tmp_data + threads * tmp_data_thread_stride;
float16_t* g_trans_remain_tmp_data = g_trans_tmp_data + threads * 288;
auto act_type = act_param.active_type;
float local_alpha = 0.f;
float16_t local_alpha = 0.f;
bool flag_bias = (bias != nullptr);
int flag_act = 0x00; // relu: 1, relu6: 2, leakey: 3
float offset = 0.f;
float threshold = 6.f;
float16_t offset = 0.f;
float16_t threshold = 6.f;

if (act_param.has_active) {
act_acquire(act_type, flag_act, local_alpha, offset, threshold, act_param);
Expand Down
6 changes: 3 additions & 3 deletions lite/backends/arm/math/fp16/conv3x3s1_direct_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,11 +492,11 @@ void conv_3x3s1_direct_fp16(const float16_t* i_data,
int out_row_stride = OUT_C_BLOCK * wout_round;
auto act_type = act_param.active_type;
bool flag_bias = param.bias != nullptr;
float alpha = 0.f;
float16_t alpha = 0.f;
int flag_act = 0x00; // relu: 1, relu6: 2, leakey: 3

float offset = 0.f;
float threshold = 6.f;
float16_t offset = 0.f;
float16_t threshold = 6.f;

if (act_param.has_active) {
act_acquire(act_type, flag_act, alpha, offset, threshold, act_param);
Expand Down
12 changes: 6 additions & 6 deletions lite/backends/arm/math/fp16/conv3x3s2_direct_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,10 +626,10 @@ void conv_3x3s2_direct_fp16(const float16_t* i_data,
int out_row_stride = OUT_C_BLOCK * wout_round;
auto act_type = act_param.active_type;
bool flag_bias = param.bias != nullptr;
float alpha = 0.f;
float16_t alpha = 0.f;
int flag_act = 0x00; // relu: 1, relu6: 2, leakey: 3 hardswish:4
float offset = 0.f;
float threshold = 6.f;
float16_t offset = 0.f;
float16_t threshold = 6.f;

if (act_param.has_active) {
act_acquire(act_type, flag_act, alpha, offset, threshold, act_param);
Expand Down Expand Up @@ -788,10 +788,10 @@ void conv_3x3s2_direct_fp16_c3(const float16_t* i_data,
int out_row_stride = OUT_C_BLOCK * wout_round;
auto act_type = act_param.active_type;
bool flag_bias = param.bias != nullptr;
float alpha = 0.f;
float16_t alpha = 0.f;
int flag_act = 0x00;
float offset = 0.f;
float threshold = 6.f;
float16_t offset = 0.f;
float16_t threshold = 6.f;

if (act_param.has_active) {
act_acquire(act_type, flag_act, alpha, offset, threshold, act_param);
Expand Down
24 changes: 13 additions & 11 deletions lite/backends/arm/math/fp16/gemm_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,8 @@ void prepackA_trans_8x16(float16_t *out,
"3: \n"
"subs %[cnt], #1 \n"
"vst1.16 {d0-d3}, [%[outptr]]! \n"
"vst1.16 {d4-d7}, [%[outptr]]! \n"
"vst1.16 {d4-d7}, [%[outptr]] \n"
"sub %[outptr], #32 \n"
"add %[outptr], %[stride] \n"
"bne 0b \n"
"1: \n"
Expand All @@ -817,7 +818,7 @@ void prepackA_trans_8x16(float16_t *out,
"vmul.f16 q3, q3, %q[valpha] \n"
"4: \n"
"vst1.16 {d0-d3}, [%[outptr]]! \n"
"vst1.16 {d4-d7}, [%[outptr]]! \n"
"vst1.16 {d4-d7}, [%[outptr]] \n"
"2: \n"
: [ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1),
Expand Down Expand Up @@ -850,7 +851,7 @@ void prepackA_trans_8x16(float16_t *out,
"vmul.f16 q0, q0, %q[valpha] \n"
"3: \n"
"subs %[cnt], #1 \n"
"vst1.16 {d0-d1}, [%[outptr]]! \n"
"vst1.16 {d0-d1}, [%[outptr]] \n"
"add %[outptr], %[stride] \n"
"bne 0b \n"
"1: \n"
Expand All @@ -862,7 +863,7 @@ void prepackA_trans_8x16(float16_t *out,
"bne 4f \n"
"vmul.f16 q0, q0, %q[valpha] \n"
"4: \n"
"vst1.16 {d0-d1}, [%[outptr]]! \n"
"vst1.16 {d0-d1}, [%[outptr]] \n"
"2: \n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(outptr_row_col), [cnt] "+r"(cnt_col)
: [right_remain] "r"(right_remain),
Expand All @@ -875,6 +876,7 @@ void prepackA_trans_8x16(float16_t *out,
}
LITE_PARALLEL_COMMON_END();
}

#endif

/**
Expand Down Expand Up @@ -1480,7 +1482,7 @@ void loadb_trans(float16_t *out,
"vld1.16 {d14-d15}, [%[inptr7]]!\n"
// c0d0c2d2c4d4c6d6
"vtrn.16 q2, q3 \n"
"vld1.16 {d16-d17}, [%[inptr8]]!\n"
"vld1.16 {d16-d17}, [%[inptr8]]\n"
// e0f0e2f2...
"vtrn.16 q4, q5 \n"
"add %[inptr8], %[stride_w]\n"
Expand Down Expand Up @@ -1874,9 +1876,9 @@ void gemm_prepack_8x16(bool is_transB,
llc_size = llc_size * 9 / 10;

auto act_type = act_param.active_type;
float local_alpha = 0.f;
float offset = 0.f;
float threshold = 6.f;
float16_t local_alpha = 0.f;
float16_t offset = 0.f;
float16_t threshold = 6.f;
int flag_act = 0x00; // relu: 1, relu6: 2, leakey: 3
if (act_param.has_active) {
act_acquire(act_type, flag_act, local_alpha, offset, threshold, act_param);
Expand Down Expand Up @@ -2634,9 +2636,9 @@ void gemm_prepack_8x12(bool is_transB,
llc_size = llc_size * 9 / 10;

auto act_type = act_param.active_type;
float local_alpha = 0.f;
float offset = 0.f;
float threshold = 6.f;
float16_t local_alpha = 0.f;
float16_t offset = 0.f;
float16_t threshold = 6.f;
int flag_act = 0x00; // relu: 1, relu6: 2, leakey: 3
if (act_param.has_active) {
act_acquire(act_type, flag_act, local_alpha, offset, threshold, act_param);
Expand Down
Loading