Skip to content

Commit d2f9aa8

Browse files
committed
optimization of index_select op backward
1 parent 9beb43b commit d2f9aa8

File tree

1 file changed

+27
-22
lines changed

1 file changed

+27
-22
lines changed

paddle/fluid/operators/index_select_op.h

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,24 @@ class IndexSelectKernel : public framework::OpKernel<T> {
126126
};
127127

128128
#if ((!defined __NVCC__) && (!defined __HIPCC__))
129-
template <typename T>
130-
void index_sum(const size_t n, const T* src, T* dst) {
129+
// Element-wise accumulation: adds src[k] into dst[k] for every k in [0, n).
// This primary template is the plain scalar loop. The float and double
// specializations for platform::avx later in this file call it as
// IndexSelectAdd<..., platform::isa_any> when __AVX__ is not defined.
// The `isa` template parameter is not used inside this body.
template <typename T, platform::cpu_isa_t isa = platform::isa_any>
130+
void IndexSelectAdd(const int n, const T* src, T* dst) {
131+
for (auto k = 0; k < n; ++k) {
132+
dst[k] += src[k];
133+
}
134+
}
135+
136+
// Description: when AVX is available, index addition uses Intel AVX intrinsics
137+
// to load, add, and store a block of elements at a time.
138+
template <>
139+
void IndexSelectAdd<float, platform::avx>(const int n, const float* src,
140+
float* dst) {
131141
#ifdef __AVX__
132142
constexpr int block = YMM_FLOAT_BLOCK;
133-
unsigned int i, end;
134-
i = end = 0;
135-
end = n & ~(block - 1);
143+
int i = 0;
144+
int end = n & ~(block - 1);
136145
for (i = 0; i < end; i += block) {
146+
// See the Intel Intrinsics Guide: https://software.intel.com/sites/landingpage/IntrinsicsGuide/
137147
_mm256_storeu_ps(reinterpret_cast<float*>(dst) + i,
138148
_mm256_add_ps(_mm256_loadu_ps((const float*)dst + i),
139149
_mm256_loadu_ps((const float*)src + i)));
@@ -142,19 +152,17 @@ void index_sum(const size_t n, const T* src, T* dst) {
142152
dst[i] += src[i];
143153
}
144154
#else
145-
for (size_t k = 0; k < n; k++) {
146-
dst[k] += src[k];
147-
}
155+
IndexSelectAdd<float, platform::isa_any>(n, src, dst);
148156
#endif
149157
}
150158

151159
template <>
152-
void index_sum(const size_t n, const double* src, double* dst) {
160+
void IndexSelectAdd<double, platform::avx>(const int n, const double* src,
161+
double* dst) {
153162
#ifdef __AVX__
154163
constexpr int block = XMM_FLOAT_BLOCK;
155-
unsigned int i, end;
156-
i = end = 0;
157-
end = n & ~(block - 1);
164+
int i = 0;
165+
int end = n & ~(block - 1);
158166
for (i = 0; i < end; i += block) {
159167
_mm256_storeu_pd(reinterpret_cast<double*>(dst) + i,
160168
_mm256_add_pd(_mm256_loadu_pd((const double*)dst + i),
@@ -164,9 +172,7 @@ void index_sum(const size_t n, const double* src, double* dst) {
164172
dst[i] += src[i];
165173
}
166174
#else
167-
for (size_t k = 0; k < n; k++) {
168-
dst[k] += src[k];
169-
}
175+
IndexSelectAdd<double, platform::isa_any>(n, src, dst);
170176
#endif
171177
}
172178
#endif
@@ -182,6 +188,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
182188
auto input_dim_size = input_dim.size();
183189
auto output_dim = x_grad->dims();
184190
std::memset(out_data, 0.0, x_grad->numel() * sizeof(T));
191+
185192
auto slice_size = 1;
186193
for (auto i = dim + 1; i < input_dim_size; i++) {
187194
slice_size *= input_dim[i];
@@ -202,17 +209,15 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
202209
auto output_start_offset = i * output_width;
203210
for (auto j = 0; j < index_size; j++) {
204211
IndexT index_value = index_data[j];
205-
#ifdef __AVX__
206212
auto src = input_data + input_start_offset + j * slice_size;
207213
auto dst = out_data + output_start_offset + index_value * slice_size;
214+
208215
#if ((!defined __NVCC__) && (!defined __HIPCC__))
209-
index_sum(slice_size, src, dst);
210-
#endif
216+
#ifdef __AVX__
217+
index_select_add<T, platform::avx>(slice_size, src, dst);
211218
#else
212-
for (auto k = 0; k < slice_size; k++) {
213-
out_data[output_start_offset + index_value * slice_size + k] +=
214-
input_data[input_start_offset + j * slice_size + k];
215-
}
219+
index_select_add<T>(slice_size, src, dst);
220+
#endif
216221
#endif
217222
}
218223
}

0 commit comments

Comments
 (0)