@@ -126,14 +126,24 @@ class IndexSelectKernel : public framework::OpKernel<T> {
126126};
127127
128128#if ((!defined __NVCC__) && (!defined __HIPCC__))
129- template <typename T>
130- void index_sum (const size_t n, const T* src, T* dst) {
// Generic element-wise accumulation: adds src[k] into dst[k] for every k in
// [0, n). When n <= 0 the loop body never executes and dst is left untouched.
//
// Template parameters:
//   T   - element type of both buffers.
//   isa - instruction-set tag; defaults to platform::isa_any and is not
//         referenced inside this generic body.
// Parameters:
//   n   - number of elements to accumulate.
//   src - values to add (read only).
//   dst - accumulated in place.
// NOTE(review): presumably src and dst must each hold at least n elements;
// nothing in this function checks that - confirm against callers.
129+ template <typename T, platform::cpu_isa_t isa = platform::isa_any>
130+ void IndexSelectAdd (const int n, const T* src, T* dst) {
131+ for (auto k = 0 ; k < n; ++k) {
132+ dst[k] += src[k];
133+ }
134+ }
135+
136+ // Specializations of IndexSelectAdd that use Intel AVX intrinsics to load,
137+ // add, and store several elements per loop iteration.
138+ template <>
139+ void IndexSelectAdd<float , platform::avx>(const int n, const float * src,
140+ float * dst) {
131141#ifdef __AVX__
132142 constexpr int block = YMM_FLOAT_BLOCK;
133- unsigned int i, end;
134- i = end = 0 ;
135- end = n & ~(block - 1 );
143+ int i = 0 ;
144+ int end = n & ~(block - 1 );
136145 for (i = 0 ; i < end; i += block) {
146+ // Quote from https://software.intel.com/sites/landingpage/IntrinsicsGuide/
137147 _mm256_storeu_ps (reinterpret_cast <float *>(dst) + i,
138148 _mm256_add_ps (_mm256_loadu_ps ((const float *)dst + i),
139149 _mm256_loadu_ps ((const float *)src + i)));
@@ -142,19 +152,17 @@ void index_sum(const size_t n, const T* src, T* dst) {
142152 dst[i] += src[i];
143153 }
144154#else
145- for (size_t k = 0 ; k < n; k++) {
146- dst[k] += src[k];
147- }
155+ IndexSelectAdd<float , platform::isa_any>(n, src, dst);
148156#endif
149157}
150158
151159template <>
152- void index_sum (const size_t n, const double * src, double * dst) {
160+ void IndexSelectAdd<double , platform::avx>(const int n, const double * src,
161+ double * dst) {
153162#ifdef __AVX__
154163 constexpr int block = XMM_FLOAT_BLOCK;
155- unsigned int i, end;
156- i = end = 0 ;
157- end = n & ~(block - 1 );
164+ int i = 0 ;
165+ int end = n & ~(block - 1 );
158166 for (i = 0 ; i < end; i += block) {
159167 _mm256_storeu_pd (reinterpret_cast <double *>(dst) + i,
160168 _mm256_add_pd (_mm256_loadu_pd ((const double *)dst + i),
@@ -164,9 +172,7 @@ void index_sum(const size_t n, const double* src, double* dst) {
164172 dst[i] += src[i];
165173 }
166174#else
167- for (size_t k = 0 ; k < n; k++) {
168- dst[k] += src[k];
169- }
175+ IndexSelectAdd<double , platform::isa_any>(n, src, dst);
170176#endif
171177}
172178#endif
@@ -182,6 +188,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
182188 auto input_dim_size = input_dim.size ();
183189 auto output_dim = x_grad->dims ();
184190 std::memset (out_data, 0.0 , x_grad->numel () * sizeof (T));
191+
185192 auto slice_size = 1 ;
186193 for (auto i = dim + 1 ; i < input_dim_size; i++) {
187194 slice_size *= input_dim[i];
@@ -202,17 +209,15 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
202209 auto output_start_offset = i * output_width;
203210 for (auto j = 0 ; j < index_size; j++) {
204211 IndexT index_value = index_data[j];
205- #ifdef __AVX__
206212 auto src = input_data + input_start_offset + j * slice_size;
207213 auto dst = out_data + output_start_offset + index_value * slice_size;
214+
208215#if ((!defined __NVCC__) && (!defined __HIPCC__))
209- index_sum (slice_size, src, dst);
210- # endif
216+ # ifdef __AVX__
217+ index_select_add<T, platform::avx>(slice_size, src, dst);
211218#else
212- for (auto k = 0 ; k < slice_size; k++) {
213- out_data[output_start_offset + index_value * slice_size + k] +=
214- input_data[input_start_offset + j * slice_size + k];
215- }
219+ index_select_add<T>(slice_size, src, dst);
220+ #endif
216221#endif
217222 }
218223 }
0 commit comments