Skip to content

Commit 1aa80c3

Browse files
[ARM NEON] Get rid of redundant instructions in ScalarQuantizer
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
1 parent 4d06d70 commit 1aa80c3

1 file changed

Lines changed: 22 additions & 36 deletions

File tree

faiss/impl/ScalarQuantizer.cpp

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ struct Codec8bit {
101101
}
102102
float32x4_t res1 = vld1q_f32(result);
103103
float32x4_t res2 = vld1q_f32(result + 4);
104-
float32x4x2_t res = vzipq_f32(res1, res2);
105-
return vuzpq_f32(res.val[0], res.val[1]);
104+
return {res1, res2};
106105
}
107106
#endif
108107
};
@@ -153,8 +152,7 @@ struct Codec4bit {
153152
}
154153
float32x4_t res1 = vld1q_f32(result);
155154
float32x4_t res2 = vld1q_f32(result + 4);
156-
float32x4x2_t res = vzipq_f32(res1, res2);
157-
return vuzpq_f32(res.val[0], res.val[1]);
155+
return {res1, res2};
158156
}
159157
#endif
160158
};
@@ -266,8 +264,7 @@ struct Codec6bit {
266264
}
267265
float32x4_t res1 = vld1q_f32(result);
268266
float32x4_t res2 = vld1q_f32(result + 4);
269-
float32x4x2_t res = vzipq_f32(res1, res2);
270-
return vuzpq_f32(res.val[0], res.val[1]);
267+
return {res1, res2};
271268
}
272269
#endif
273270
};
@@ -345,16 +342,14 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
345342
FAISS_ALWAYS_INLINE float32x4x2_t
346343
reconstruct_8_components(const uint8_t* code, int i) const {
347344
float32x4x2_t xi = Codec::decode_8_components(code, i);
348-
float32x4x2_t res = vzipq_f32(
349-
vfmaq_f32(
345+
return {vfmaq_f32(
350346
vdupq_n_f32(this->vmin),
351347
xi.val[0],
352348
vdupq_n_f32(this->vdiff)),
353349
vfmaq_f32(
354350
vdupq_n_f32(this->vmin),
355351
xi.val[1],
356-
vdupq_n_f32(this->vdiff)));
357-
return vuzpq_f32(res.val[0], res.val[1]);
352+
vdupq_n_f32(this->vdiff))};
358353
}
359354
};
360355

@@ -431,10 +426,8 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
431426
float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
432427
float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);
433428

434-
float32x4x2_t res = vzipq_f32(
435-
vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
436-
vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
437-
return vuzpq_f32(res.val[0], res.val[1]);
429+
return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
430+
vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])};
438431
}
439432
};
440433

@@ -496,10 +489,9 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
496489

497490
FAISS_ALWAYS_INLINE float32x4x2_t
498491
reconstruct_8_components(const uint8_t* code, int i) const {
499-
uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
500-
return vzipq_f32(
501-
vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
502-
vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
492+
uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
493+
return {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
494+
vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))};
503495
}
504496
};
505497
#endif
@@ -568,8 +560,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
568560
}
569561
float32x4_t res1 = vld1q_f32(result);
570562
float32x4_t res2 = vld1q_f32(result + 4);
571-
float32x4x2_t res = vzipq_f32(res1, res2);
572-
return vuzpq_f32(res.val[0], res.val[1]);
563+
return {res1, res2};
573564
}
574565
};
575566

@@ -868,7 +859,7 @@ struct SimilarityL2<8> {
868859
float32x4x2_t accu8;
869860

870861
FAISS_ALWAYS_INLINE void begin_8() {
871-
accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
862+
accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
872863
yi = y;
873864
}
874865

@@ -882,8 +873,7 @@ struct SimilarityL2<8> {
882873
float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
883874
float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
884875

885-
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
886-
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
876+
accu8 = {accu8_0, accu8_1};
887877
}
888878

889879
FAISS_ALWAYS_INLINE void add_8_components_2(
@@ -895,8 +885,7 @@ struct SimilarityL2<8> {
895885
float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
896886
float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
897887

898-
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
899-
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
888+
accu8 = {accu8_0, accu8_1};
900889
}
901890

902891
FAISS_ALWAYS_INLINE float result_8() {
@@ -996,7 +985,7 @@ struct SimilarityIP<8> {
996985
float32x4x2_t accu8;
997986

998987
FAISS_ALWAYS_INLINE void begin_8() {
999-
accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
988+
accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
1000989
yi = y;
1001990
}
1002991

@@ -1006,28 +995,25 @@ struct SimilarityIP<8> {
1006995

1007996
float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
1008997
float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
1009-
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1010-
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
998+
accu8 = {accu8_0, accu8_1};
1011999
}
10121000

10131001
FAISS_ALWAYS_INLINE void add_8_components_2(
10141002
float32x4x2_t x1,
10151003
float32x4x2_t x2) {
10161004
float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
10171005
float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
1018-
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1019-
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
1006+
accu8 = {accu8_0, accu8_1};
10201007
}
10211008

10221009
FAISS_ALWAYS_INLINE float result_8() {
1023-
float32x4x2_t sum_tmp = vzipq_f32(
1010+
float32x4x2_t sum = {
10241011
vpaddq_f32(accu8.val[0], accu8.val[0]),
1025-
vpaddq_f32(accu8.val[1], accu8.val[1]));
1026-
float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
1027-
float32x4x2_t sum2_tmp = vzipq_f32(
1012+
vpaddq_f32(accu8.val[1], accu8.val[1])};
1013+
1014+
float32x4x2_t sum2 = {
10281015
vpaddq_f32(sum.val[0], sum.val[0]),
1029-
vpaddq_f32(sum.val[1], sum.val[1]));
1030-
float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
1016+
vpaddq_f32(sum.val[1], sum.val[1])};
10311017
return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
10321018
}
10331019
};

0 commit comments

Comments
 (0)