|
17 | 17 |
|
namespace {

// Computes a horizontal sum over an __m128 register.
//
// Returns the sum of the four packed floats in `v`. The additions are
// performed as (v[0] + v[2]) + (v[1] + v[3]).
inline float horizontal_sum(const __m128 v) {
  // Bring the upper pair (v[2], v[3]) down into lanes 0 and 1. Lanes 2 and 3
  // of the shuffle result do not contribute to the returned value.
  const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
  // Lane 0 = v[0] + v[2], lane 1 = v[1] + v[3].
  const __m128 v1 = _mm_add_ps(v, v0);
  // Move the lane-1 partial sum into lane 0 so the two partials can be added.
  const __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
  // Lane 0 now holds the sum of all four input lanes.
  const __m128 v3 = _mm_add_ps(v1, v2);
  // Extract lane 0 as a scalar.
  return _mm_cvtss_f32(v3);
}

// Computes a horizontal sum over an __m256 register.
//
// The low and high 128-bit halves are added lane-wise first, and the
// resulting four partial sums are reduced by the __m128 overload.
inline float horizontal_sum(const __m256 v) {
  const __m128 v0 =
      _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
  return horizontal_sum(v0);
}

}  // namespace
|
0 commit comments