Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 98 additions & 55 deletions lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,16 @@ void conv_compute_6x6_3x3(const float* input,
// trans output
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
if (ci * 4 + 4 < chout) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加上这些 if 后,会不会影响性能?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

会的,但是没有这个判断,会有读越界风险。在精度和性能上只能选择精度。

bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
} else {
for (int p = 0; p < 4 && ci * 4 + p < chout; p++) {
bias_value[p] = bias[ci * 4 + p];
}
}
}

float* dst_ci = dst_ptr + ci * oc_4_stride;
Expand Down Expand Up @@ -320,10 +326,16 @@ void conv_compute_6x6_3x3(const float* input,
} else {
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
if (ci * 4 + 4 < chout) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
} else {
for (int p = 0; p < 4 && ci * 4 + p < chout; p++) {
bias_value[p] = bias[ci * 4 + p];
}
}
}
// trans output
float* dst_ci = dst_ptr + ci * oc_4_stride;
Expand Down Expand Up @@ -568,10 +580,16 @@ void conv_compute_4x4_3x3(const float* input,
// trans output
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
if (ci * 4 + 4 < chout) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
} else {
for (int p = 0; p < 4 && ci * 4 + p < chout; p++) {
bias_value[p] = bias[ci * 4 + p];
}
}
}

float* dst_ci = dst_ptr + ci * oc_4_stride;
Expand Down Expand Up @@ -608,10 +626,16 @@ void conv_compute_4x4_3x3(const float* input,
} else {
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
if (ci * 4 + 4 < chout) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
} else {
for (int p = 0; p < 4 && ci * 4 + p < chout; p++) {
bias_value[p] = bias[ci * 4 + p];
}
}
}
// trans output
float* dst_ci = dst_ptr + ci * oc_4_stride;
Expand Down Expand Up @@ -849,10 +873,16 @@ void conv_compute_2x2_3x3(const float* input,
// trans output
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
if (ci * 4 + 4 < chout) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
} else {
for (int p = 0; p < 4 && ci * 4 + p < chout; p++) {
bias_value[p] = bias[ci * 4 + p];
}
}
}

float* dst_ci = dst_ptr + ci * oc_4_stride;
Expand Down Expand Up @@ -883,11 +913,18 @@ void conv_compute_2x2_3x3(const float* input,
} else {
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
if (ci * 4 + 4 < chout) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
} else {
for (int p = 0; p < 4 && ci * 4 + p < chout; p++) {
bias_value[p] = bias[ci * 4 + p];
}
}
}

// trans output
float* dst_ci = dst_ptr + ci * oc_4_stride;
float* src_ci = src_ptr + ci * tile_count * 4;
Expand Down Expand Up @@ -1106,10 +1143,16 @@ void conv_compute_2x2_3x3_small(const float* input,
// trans output
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
if (ci * 4 + 4 < chout) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
} else {
for (int p = 0; p < 4 && ci * 4 + p < chout; p++) {
bias_value[p] = bias[ci * 4 + p];
}
}
}

float* dst_ci = dst_ptr + ci * oc_4_stride;
Expand Down Expand Up @@ -1141,10 +1184,16 @@ void conv_compute_2x2_3x3_small(const float* input,
} else {
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
if (ci * 4 + 4 < chout) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
} else {
for (int p = 0; p < 4 && ci * 4 + p < chout; p++) {
bias_value[p] = bias[ci * 4 + p];
}
}
}
// trans output
float* dst_ci = dst_ptr + ci * oc_4_stride;
Expand Down Expand Up @@ -1278,15 +1327,13 @@ void output_trans_c4_post_6x8(const float* src,
vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 32)),
vmulq_n_f32(tmp135c, 0.03125f)));

if (bias_value) {
float32x4_t bias = vld1q_f32(bias_value);
dest0 = vaddq_f32(dest0, bias);
dest1 = vaddq_f32(dest1, bias);
dest2 = vaddq_f32(dest2, bias);
dest3 = vaddq_f32(dest3, bias);
dest4 = vaddq_f32(dest4, bias);
dest5 = vaddq_f32(dest5, bias);
}
float32x4_t bias = vld1q_f32(bias_value);
dest0 = vaddq_f32(dest0, bias);
dest1 = vaddq_f32(dest1, bias);
dest2 = vaddq_f32(dest2, bias);
dest3 = vaddq_f32(dest3, bias);
dest4 = vaddq_f32(dest4, bias);
dest5 = vaddq_f32(dest5, bias);

vst1q_f32(dest, dest0);
vst1q_f32(dest + dest_stride, dest1);
Expand Down Expand Up @@ -1356,13 +1403,11 @@ void output_trans_c4_post_4x6(const float* src,
float32x4_t dest3 =
vaddq_f32(vaddq_f32(tmp13a, vmulq_n_f32(tmp13b, 8)), src5);

if (bias_value) {
float32x4_t bias = vld1q_f32(bias_value);
dest0 = vaddq_f32(dest0, bias);
dest1 = vaddq_f32(dest1, bias);
dest2 = vaddq_f32(dest2, bias);
dest3 = vaddq_f32(dest3, bias);
}
float32x4_t bias = vld1q_f32(bias_value);
dest0 = vaddq_f32(dest0, bias);
dest1 = vaddq_f32(dest1, bias);
dest2 = vaddq_f32(dest2, bias);
dest3 = vaddq_f32(dest3, bias);

vst1q_f32(dest, dest0);
vst1q_f32(dest + dest_stride, dest1);
Expand Down Expand Up @@ -1608,13 +1653,11 @@ void output_trans_c4_post_2x4(const float* src,
float32x4_t dest01 = vaddq_f32(vaddq_f32(dst10, dst11), dst12);
float32x4_t dest11 = vsubq_f32(vsubq_f32(dst11, dst12), dst13);

if (bias_value) {
float32x4_t bias = vld1q_f32(bias_value);
dest00 = vaddq_f32(dest00, bias);
dest10 = vaddq_f32(dest10, bias);
dest01 = vaddq_f32(dest01, bias);
dest11 = vaddq_f32(dest11, bias);
}
float32x4_t bias = vld1q_f32(bias_value);
dest00 = vaddq_f32(dest00, bias);
dest10 = vaddq_f32(dest10, bias);
dest01 = vaddq_f32(dest01, bias);
dest11 = vaddq_f32(dest11, bias);

vst1q_f32(dest, dest00);
vst1q_f32(dest + dest_stride, dest10);
Expand Down
13 changes: 1 addition & 12 deletions lite/backends/arm/math/conv3x3s1_depthwise_int8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -820,17 +820,6 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
const int8_t* block_inr2 = block_inr1 + in_len;

const int8_t* weight_c = weights + c * w_stride;
float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
bias_local[4] = bias[c + 4];
bias_local[5] = bias[c + 5];
bias_local[6] = bias[c + 6];
bias_local[7] = bias[c + 7];
}
#ifdef __aarch64__
int8x8_t vw0 = vld1_s8(weight_c);
int8x8_t vw1 = vld1_s8(weight_c + 8);
Expand Down Expand Up @@ -1132,7 +1121,7 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
wout,
flag_act,
alpha,
bias_local,
bias + c,
flag_bias,
ptr_write,
scale + c);
Expand Down
9 changes: 1 addition & 8 deletions lite/backends/arm/math/conv3x3s1_direct_int8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,6 @@ void conv_3x3s1_direct_int8(const int8_t* din,
const int8_t* block_inr3 = block_inr2 + in_len;

const int8_t* weight_c = weights + c * w_stride;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
memset(pre_out, 0, pre_out_size * sizeof(int32_t));
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
const int8_t* wc0 = weight_c;
Expand Down Expand Up @@ -477,7 +470,7 @@ void conv_3x3s1_direct_int8(const int8_t* din,
wout,
flag_act,
alpha,
bias_local,
bias + c,
flag_bias,
ptr_write,
scale + c);
Expand Down
13 changes: 1 addition & 12 deletions lite/backends/arm/math/conv3x3s2_depthwise_int8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,6 @@ void conv_depthwise_3x3s2_common_int8(Dtype* dout,
const int8_t* block_inr2 = block_inr1 + in_len;

const int8_t* weight_c = weights + c * w_stride;
float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
bias_local[4] = bias[c + 4];
bias_local[5] = bias[c + 5];
bias_local[6] = bias[c + 6];
bias_local[7] = bias[c + 7];
}
#ifdef __aarch64__
int8x8_t vw0 = vld1_s8(weight_c);
int8x8_t vw1 = vld1_s8(weight_c + 8);
Expand Down Expand Up @@ -453,7 +442,7 @@ void conv_depthwise_3x3s2_common_int8(Dtype* dout,
wout,
flag_act,
alpha,
bias_local,
bias + c,
flag_bias,
ptr_write,
scale + c);
Expand Down
23 changes: 2 additions & 21 deletions lite/backends/arm/math/conv3x3s2_direct_int8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,6 @@ void conv_3x3s2_direct_int8(const int8_t* din,
const int8_t* block_inr4 = cblock_inr4;

const int8_t* weight_c = weights + c * w_stride;
float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
bias_local[4] = bias[c + 4];
bias_local[5] = bias[c + 5];
bias_local[6] = bias[c + 6];
bias_local[7] = bias[c + 7];
}

memset(pre_out, 0, pre_out_size * sizeof(int32_t));
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
const int8_t* wc0 = weight_c;
Expand Down Expand Up @@ -476,7 +464,7 @@ void conv_3x3s2_direct_int8(const int8_t* din,
wout,
flag_act,
alpha,
bias_local,
bias + c,
flag_bias,
ptr_write,
scale + c);
Expand Down Expand Up @@ -629,13 +617,6 @@ void conv_3x3s2_direct_int8(const int8_t* din,
const int8_t* block_inr1 = cblock_inr1;
const int8_t* block_inr2 = cblock_inr2;
const int8_t* weight_c = weights + c * w_stride;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
memset(pre_out, 0, pre_out_size * sizeof(int32_t));
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
const int8_t* wc0 = weight_c;
Expand Down Expand Up @@ -766,7 +747,7 @@ void conv_3x3s2_direct_int8(const int8_t* din,
wout,
flag_act,
alpha,
bias_local,
bias + c,
flag_bias,
ptr_write,
scale + c);
Expand Down
Loading