Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,12 @@ void llama_context::sched_reserve() {

if (cparams.fused_gdn_ch) {
// more than one token in the batch per sequence in order to take the chunked path
auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
// note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
// because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
// it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
// the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
const uint32_t n_tokens_ch = 16*n_seqs;
auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
}
Expand Down
34 changes: 34 additions & 0 deletions tools/server/tests/unit/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,40 @@ def test_embedding_mixed_input(input, is_multi_prompt: bool):
assert len(data[0]['embedding']) > 1


def test_embedding_pooling_mean():
global server
server.pooling = 'mean'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
assert len(res.body['data']) == 1
assert 'embedding' in res.body['data'][0]
assert len(res.body['data'][0]['embedding']) > 1

# make sure embedding vector is normalized
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON


def test_embedding_pooling_mean_multiple():
global server
server.pooling = 'mean'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI",
"This is a test",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 3
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1


def test_embedding_pooling_none():
global server
server.pooling = 'none'
Expand Down
Loading