@@ -8154,6 +8154,18 @@ void ggml_compute_forward_timestep_embedding(
81548154
81558155// ggml_compute_forward_argsort
81568156
8157+ template <enum ggml_sort_order order>
8158+ struct argsort_cmp {
8159+ const float * data;
8160+ bool operator ()(int32_t a, int32_t b) const {
8161+ if constexpr (order == GGML_SORT_ORDER_ASC) {
8162+ return data[a] < data[b];
8163+ } else {
8164+ return data[a] > data[b];
8165+ }
8166+ }
8167+ };
8168+
81578169static void ggml_compute_forward_argsort_f32 (
81588170 const ggml_compute_params * params,
81598171 ggml_tensor * dst) {
@@ -8180,16 +8192,18 @@ static void ggml_compute_forward_argsort_f32(
81808192 dst_data[j] = j;
81818193 }
81828194
8183- std::function<bool (int32_t , int32_t )> cmp;
8184-
8185- // note: this might be causing memory allocations? ideally should be avoided if it's the case
81868195 switch (order) {
8187- case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break ;
8188- case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break ;
8189- default : GGML_ABORT (" invalid sort order" );
8190- }
8196+ case GGML_SORT_ORDER_ASC:
8197+ std::sort (dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
8198+ break ;
81918199
8192- std::sort (dst_data, dst_data + ne0, cmp);
8200+ case GGML_SORT_ORDER_DESC:
8201+ std::sort (dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
8202+ break ;
8203+
8204+ default :
8205+ GGML_ABORT (" invalid sort order" );
8206+ }
81938207 }
81948208}
81958209
0 commit comments