Skip to content

Commit 463742b

Browse files
slaren, shaobo.xie
authored and committed
ggml-cpu : use template for argsort (#17222)
1 parent 5b18132 commit 463742b

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 22 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8154,6 +8154,18 @@ void ggml_compute_forward_timestep_embedding(
81548154

81558155
// ggml_compute_forward_argsort
81568156

8157+
// Index comparator used to argsort a row of floats: it orders two indices by
// the values they refer to in `data`, not by the indices themselves.
// The sort direction is a template parameter, so the branch below is chosen
// at compile time and each instantiation contains a single comparison.
// NOTE(review): if `data` can contain NaN, both `<` and `>` return false for
// it and the ordering is no longer a strict weak ordering as std::sort
// requires - confirm inputs are NaN-free.
template<enum ggml_sort_order order>
8158+
struct argsort_cmp {
8159+
const float * data; // values being ranked; not owned by the comparator
8160+
bool operator()(int32_t a, int32_t b) const { // true when index a must come before index b
8161+
if constexpr (order == GGML_SORT_ORDER_ASC) {
8162+
return data[a] < data[b]; // ascending: smaller value first
8163+
} else { // any order other than ASC is treated as descending
8164+
return data[a] > data[b]; // descending: larger value first
8165+
}
8166+
}
8167+
};
8168+
81578169
static void ggml_compute_forward_argsort_f32(
81588170
const ggml_compute_params * params,
81598171
ggml_tensor * dst) {
@@ -8180,16 +8192,18 @@ static void ggml_compute_forward_argsort_f32(
81808192
dst_data[j] = j;
81818193
}
81828194

8183-
std::function<bool(int32_t, int32_t)> cmp;
8184-
8185-
// note: this might be causing memory allocations? ideally should be avoided if it's the case
81868195
switch (order) {
8187-
case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
8188-
case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
8189-
default: GGML_ABORT("invalid sort order");
8190-
}
8196+
case GGML_SORT_ORDER_ASC:
8197+
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
8198+
break;
81918199

8192-
std::sort(dst_data, dst_data + ne0, cmp);
8200+
case GGML_SORT_ORDER_DESC:
8201+
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
8202+
break;
8203+
8204+
default:
8205+
GGML_ABORT("invalid sort order");
8206+
}
81938207
}
81948208
}
81958209

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8312,6 +8312,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
83128312
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
83138313
}
83148314

8315+
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
8316+
83158317
return test_cases;
83168318
}
83178319

0 commit comments

Comments
 (0)