@@ -175,6 +175,17 @@ void IndexSelectAdd<double, platform::avx>(const int n, const double* src,
175175 IndexSelectAdd<double , platform::isa_any>(n, src, dst);
176176#endif
177177}
178+
179+ template <>
180+ void IndexSelectAdd<int64_t , platform::avx>(const int n, const int64_t * src,
181+ int64_t * dst) {
182+ IndexSelectAdd<int64_t , platform::isa_any>(n, src, dst);
183+ }
184+
185+ template <>
186+ void IndexSelectAdd<int , platform::avx>(const int n, const int * src, int * dst) {
187+ IndexSelectAdd<int , platform::isa_any>(n, src, dst);
188+ }
178189#endif
179190
180191template <typename T, typename IndexT = int >
@@ -214,9 +225,9 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
214225
215226#if ((!defined __NVCC__) && (!defined __HIPCC__))
216227#ifdef __AVX__
217- index_select_add <T, platform::avx>(slice_size, src, dst);
228+ IndexSelectAdd <T, platform::avx>(slice_size, src, dst);
218229#else
219- index_select_add <T>(slice_size, src, dst);
230+ IndexSelectAdd <T>(slice_size, src, dst);
220231#endif
221232#endif
222233 }
0 commit comments