@@ -1350,7 +1350,7 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void send_embedding(llama_client_slot &slot)
+    void send_embedding(llama_client_slot &slot, const llama_batch & batch)
     {
         task_result res;
         res.id = slot.task_id;
@@ -1372,10 +1372,38 @@ struct llama_server_context
         else
         {
             const float *data = llama_get_embeddings(ctx);
-            std::vector<float> embedding(data, data + n_embd);
+            std::vector<float> embd_res(n_embd, 0.0f);
+            std::vector<std::vector<float>> embedding;
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+
+                    continue;
+                }
+
+                // normalize only when there is pooling
+                // TODO: configurable
+                if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
+                    common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                    embedding.push_back(embd_res);
+                } else {
+                    embedding.push_back({ embd, embd + n_embd });
+                }
+            }
+
+            // OAI compat
             res.result_json = json
             {
-                {"embedding", embedding},
+                {"embedding", embedding[0]},
             };
         }
         queue_results.send(res);
@@ -1996,7 +2024,7 @@ struct llama_server_context
                 // prompt evaluated for embedding
                 if (slot.embedding)
                 {
-                    send_embedding(slot);
+                    send_embedding(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
                     continue;
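
For reference, below is a minimal standalone sketch of the extraction loop this patch adds, assuming the llama.h batch/embedding API it targets (llama_get_embeddings_seq, llama_get_embeddings_ith, llama_pooling_type). The helper name collect_slot_embeddings and the hand-rolled L2 normalization (standing in for common_embd_normalize(..., 2)) are illustrative only, not part of the server code.

#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

#include "llama.h"

// Hypothetical helper: gather one embedding per output token belonging to
// slot_id, preferring the pooled per-sequence embedding and falling back to
// the per-token embedding, mirroring send_embedding() above.
static std::vector<std::vector<float>> collect_slot_embeddings(
        llama_context * ctx, const llama_batch & batch, int slot_id, int n_embd) {
    std::vector<std::vector<float>> out;

    for (int i = 0; i < batch.n_tokens; ++i) {
        // skip rows without outputs or belonging to another slot's sequence
        if (!batch.logits[i] || batch.seq_id[i][0] != slot_id) {
            continue;
        }

        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        if (embd == NULL) {
            embd = llama_get_embeddings_ith(ctx, i);
        }
        if (embd == NULL) {
            fprintf(stderr, "failed to get embeddings for batch row %d\n", i);
            continue;
        }

        std::vector<float> res(embd, embd + n_embd);

        // normalize only when the context pools token embeddings into a
        // single vector per sequence (the "TODO: configurable" branch above)
        if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
            float norm = 0.0f;
            for (float v : res) {
                norm += v * v;
            }
            norm = std::sqrt(norm);
            if (norm > 0.0f) {
                for (float & v : res) {
                    v /= norm;
                }
            }
        }

        out.push_back(std::move(res));
    }

    return out;
}

Note that although the loop collects every matching row, the handler still returns only embedding[0] in the JSON response, since the OpenAI-compatible endpoint expects a single vector per input.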