Skip to content

Commit d121f02

Browse files
committed
optimization of index_select op backward
1 parent d2f9aa8 commit d121f02

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

paddle/fluid/operators/index_select_op.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,17 @@ void IndexSelectAdd<double, platform::avx>(const int n, const double* src,
175175
IndexSelectAdd<double, platform::isa_any>(n, src, dst);
176176
#endif
177177
}
178+
179+
// Explicit specialization of IndexSelectAdd for int64_t elements under the
// platform::avx tag. It contains no AVX-specific code of its own: it forwards
// (n, src, dst) unchanged to the platform::isa_any instantiation, so the
// call site that requests IndexSelectAdd<T, platform::avx> has an int64_t
// overload to resolve to.
// NOTE(review): presumably n is the element count and src is accumulated
// into dst -- confirm against the isa_any definition, which is outside this
// hunk.
template <>
180+
void IndexSelectAdd<int64_t, platform::avx>(const int n, const int64_t* src,
181+
int64_t* dst) {
182+
IndexSelectAdd<int64_t, platform::isa_any>(n, src, dst);
183+
}
184+
185+
// Explicit specialization of IndexSelectAdd for int elements under the
// platform::avx tag. Like the int64_t specialization, it only forwards
// (n, src, dst) unchanged to the platform::isa_any instantiation; no
// vectorized path is implemented here.
// NOTE(review): the semantics of n/src/dst come from the isa_any definition,
// which is not visible in this hunk -- verify there.
template <>
186+
void IndexSelectAdd<int, platform::avx>(const int n, const int* src, int* dst) {
187+
IndexSelectAdd<int, platform::isa_any>(n, src, dst);
188+
}
178189
#endif
179190

180191
template <typename T, typename IndexT = int>
@@ -214,9 +225,9 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
214225

215226
#if ((!defined __NVCC__) && (!defined __HIPCC__))
216227
#ifdef __AVX__
217-
index_select_add<T, platform::avx>(slice_size, src, dst);
228+
IndexSelectAdd<T, platform::avx>(slice_size, src, dst);
218229
#else
219-
index_select_add<T>(slice_size, src, dst);
230+
IndexSelectAdd<T>(slice_size, src, dst);
220231
#endif
221232
#endif
222233
}

0 commit comments

Comments
 (0)