Commit 1c3c975

[FEATURE] Enables /score endpoint for embedding models (#12846)
1 parent 1cdc886 commit 1c3c975

File tree

11 files changed (+599, -522 lines)


docs/source/models/pooling_models.md

Lines changed: 1 addition & 2 deletions

@@ -108,8 +108,7 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/clas
 ### `LLM.score`
 
 The {class}`~vllm.LLM.score` method outputs similarity scores between sentence pairs.
-It is primarily designed for [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html).
-These types of models serve as rerankers between candidate query-document pairs in RAG systems.
+It is designed for embedding models and cross encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
 
 :::{note}
 vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
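
For context, the behaviour this doc change describes can be sketched offline roughly as follows. This is a minimal sketch, not part of the commit: the `task="score"` argument and the `output.outputs.score` field are assumptions about the vLLM API, and the model name is taken from the test matrix below.

```python
# Hedged sketch: score sentence pairs with an embedding model.
# With an embedding model the returned score is the cosine similarity of the
# two embeddings; with a cross-encoder it is the model's predicted score.
from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="score")  # assumed setup

outputs = llm.score(
    "What is the capital of France?",
    ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
)
for output in outputs:
    print(output.outputs.score)
```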

docs/source/serving/openai_compatible_server.md

Lines changed: 5 additions & 5 deletions

@@ -51,7 +51,7 @@ In addition, we have the following custom APIs:
 - [Pooling API](#pooling-api) (`/pooling`)
   - Applicable to all [pooling models](../models/pooling_models.md).
 - [Score API](#score-api) (`/score`)
-  - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
+  - Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`).
 - [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
   - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
   - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
@@ -333,10 +333,10 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
 
 ### Score API
 
-Our Score API applies a cross-encoder model to predict scores for sentence pairs.
+Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair.
 Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1.
 
-You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
+You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
 
 Code example: <gh-file:examples/online_serving/openai_cross_encoder_score.py>
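
As a point of reference for the request shape these docs describe (and which the updated tests below exercise), here is a hedged client-side sketch. The host, port and model name are assumptions; the `--task score` flag is the one mentioned in the API list above.

```python
# Hedged sketch of a /score request against a running vLLM server,
# e.g. one started with: vllm serve BAAI/bge-base-en-v1.5 --task score
import requests

response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-base-en-v1.5",
        "text_1": "What is the capital of France?",
        "text_2": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    },
)
response.raise_for_status()
# For an embedding model, each score is the cosine similarity of the pair.
for item in response.json()["data"]:
    print(item["score"])
```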

@@ -496,11 +496,11 @@ The following extra parameters are supported:
 
 ### Re-rank API
 
-Our Re-rank API applies a cross-encoder model to predict relevant scores between a single query, and
+Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and
 each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on
 a scale of 0 to 1.
 
-You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
+You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
 
 The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the
 `score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank`
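
For orientation, a hedged sketch of a re-rank request in the Jina-style shape referenced above. The host and port are assumptions; the model is the `BAAI/bge-reranker-base` mentioned in this section, and the field names mirror those checked in the re-rank tests below.

```python
# Hedged sketch of a /rerank request against a running vLLM server,
# e.g. one started with: vllm serve BAAI/bge-reranker-base
import requests

response = requests.post(
    "http://localhost:8000/rerank",
    json={
        "model": "BAAI/bge-reranker-base",
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    },
)
response.raise_for_status()
for result in response.json()["results"]:
    # Each result reports the document index and its relevance score.
    print(result["index"], result["relevance_score"])
```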

tests/entrypoints/openai/test_rerank.py

Lines changed: 2 additions & 4 deletions

@@ -8,17 +8,17 @@
 from ...utils import RemoteOpenAIServer
 
 MODEL_NAME = "BAAI/bge-reranker-base"
+DTYPE = "bfloat16"
 
 
 @pytest.fixture(scope="module")
 def server():
-    args = ["--enforce-eager", "--max-model-len", "100"]
+    args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
 
 
-@pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
     query = "What is the capital of France?"

@@ -42,7 +42,6 @@ def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
     assert rerank.results[1].relevance_score <= 0.01
 
 
-@pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_top_n(server: RemoteOpenAIServer, model_name: str):
     query = "What is the capital of France?"

@@ -68,7 +67,6 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str):
     assert rerank.results[1].relevance_score <= 0.01
 
 
-@pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):

tests/entrypoints/openai/test_score.py

Lines changed: 173 additions & 111 deletions
@@ -1,123 +1,185 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import math
+from typing import Any
+
 import pytest
 import requests
+import torch.nn.functional as F
+from torch import tensor
 
 from vllm.entrypoints.openai.protocol import ScoreResponse
 
 from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "BAAI/bge-reranker-v2-m3"
-
-
-@pytest.fixture(scope="module")
-def server():
-    args = ["--enforce-eager", "--max-model-len", "100"]
-
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+MODELS = [
+    {
+        "name": "BAAI/bge-reranker-v2-m3",
+        "is_cross_encoder": True
+    },
+    {
+        "name": "BAAI/bge-base-en-v1.5",
+        "is_cross_encoder": False
+    },
+]
+DTYPE = "half"
+
+
+def run_transformers(hf_model, model, text_pairs):
+    if model["is_cross_encoder"]:
+        return hf_model.predict(text_pairs).tolist()
+    else:
+        hf_embeddings = [
+            hf_model.encode(text_pair) for text_pair in text_pairs
+        ]
+        return [
+            F.cosine_similarity(tensor(pair[0]), tensor(pair[1]), dim=0)
+            for pair in hf_embeddings
+        ]
+
+
+@pytest.fixture(scope="class", params=MODELS)
+def model(request):
+    yield request.param
+
+
+@pytest.fixture(scope="class")
+def server(model: dict[str, Any]):
+    args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
+
+    with RemoteOpenAIServer(model["name"], args) as remote_server:
         yield remote_server
 
 
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str):
-    text_1 = "What is the capital of France?"
-    text_2 = [
-        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
-    ]
-
-    score_response = requests.post(server.url_for("score"),
-                                   json={
-                                       "model": model_name,
-                                       "text_1": text_1,
-                                       "text_2": text_2,
-                                   })
-    score_response.raise_for_status()
-    score = ScoreResponse.model_validate(score_response.json())
-
-    assert score.id is not None
-    assert score.data is not None
-    assert len(score.data) == 2
-    assert score.data[0].score <= 0.01
-    assert score.data[1].score >= 0.9
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
-    text_1 = [
-        "What is the capital of the United States?",
-        "What is the capital of France?"
-    ]
-    text_2 = [
-        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
-    ]
-
-    score_response = requests.post(server.url_for("score"),
-                                   json={
-                                       "model": model_name,
-                                       "text_1": text_1,
-                                       "text_2": text_2,
-                                   })
-    score_response.raise_for_status()
-    score = ScoreResponse.model_validate(score_response.json())
-
-    assert score.id is not None
-    assert score.data is not None
-    assert len(score.data) == 2
-    assert score.data[0].score <= 0.01
-    assert score.data[1].score >= 0.9
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
-    text_1 = "What is the capital of France?"
-    text_2 = "The capital of France is Paris."
-
-    score_response = requests.post(server.url_for("score"),
-                                   json={
-                                       "model": model_name,
-                                       "text_1": text_1,
-                                       "text_2": text_2,
-                                   })
-    score_response.raise_for_status()
-    score = ScoreResponse.model_validate(score_response.json())
-
-    assert score.id is not None
-    assert score.data is not None
-    assert len(score.data) == 1
-    assert score.data[0].score >= 0.9
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
-
-    text_1 = "What is the capital of France?" * 20
-    text_2 = [
-        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
-    ]
-
-    score_response = requests.post(server.url_for("score"),
-                                   json={
-                                       "model": model_name,
-                                       "text_1": text_1,
-                                       "text_2": text_2,
-                                   })
-    assert score_response.status_code == 400
-    # Assert just a small fragments of the response
-    assert "Please reduce the length of the input." in \
-        score_response.text
-
-    # Test truncation
-    score_response = requests.post(server.url_for("score"),
-                                   json={
-                                       "model": model_name,
-                                       "text_1": text_1,
-                                       "text_2": text_2,
-                                       "truncate_prompt_tokens": 101
-                                   })
-    assert score_response.status_code == 400
-    assert "Please, select a smaller truncation size." in \
-        score_response.text
+@pytest.fixture(scope="class")
+def runner(model: dict[str, Any], hf_runner):
+    kwargs = {
+        "dtype": DTYPE,
+        "is_cross_encoder" if model["is_cross_encoder"]\
+            else "is_sentence_transformer": True
+    }
+
+    with hf_runner(model["name"], **kwargs) as hf_model:
+        yield hf_model
+
+
+class TestModel:
+
+    def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer,
+                                    model: dict[str, Any], runner):
+        text_1 = "What is the capital of France?"
+        text_2 = [
+            "The capital of Brazil is Brasilia.",
+            "The capital of France is Paris."
+        ]
+
+        score_response = requests.post(server.url_for("score"),
+                                       json={
+                                           "model": model["name"],
+                                           "text_1": text_1,
+                                           "text_2": text_2,
+                                       })
+        score_response.raise_for_status()
+        score = ScoreResponse.model_validate(score_response.json())
+
+        assert score.id is not None
+        assert score.data is not None
+        assert len(score.data) == 2
+
+        vllm_outputs = [d.score for d in score.data]
+
+        text_pairs = [[text_1, text_2[0]], [text_1, text_2[1]]]
+        hf_outputs = run_transformers(runner, model, text_pairs)
+
+        for i in range(len(vllm_outputs)):
+            assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
+
+    def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer,
+                                     model: dict[str, Any], runner):
+        text_1 = [
+            "What is the capital of the United States?",
+            "What is the capital of France?"
+        ]
+        text_2 = [
+            "The capital of Brazil is Brasilia.",
+            "The capital of France is Paris."
+        ]
+
+        score_response = requests.post(server.url_for("score"),
+                                       json={
+                                           "model": model["name"],
+                                           "text_1": text_1,
+                                           "text_2": text_2,
+                                       })
+        score_response.raise_for_status()
+        score = ScoreResponse.model_validate(score_response.json())
+
+        assert score.id is not None
+        assert score.data is not None
+        assert len(score.data) == 2
+
+        vllm_outputs = [d.score for d in score.data]
+
+        text_pairs = [[text_1[0], text_2[0]], [text_1[1], text_2[1]]]
+        hf_outputs = run_transformers(runner, model, text_pairs)
+
+        for i in range(len(vllm_outputs)):
+            assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
+
+    def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer,
+                                   model: dict[str, Any], runner):
+        text_1 = "What is the capital of France?"
+        text_2 = "The capital of France is Paris."
+
+        score_response = requests.post(server.url_for("score"),
+                                       json={
+                                           "model": model["name"],
+                                           "text_1": text_1,
+                                           "text_2": text_2,
+                                       })
+        score_response.raise_for_status()
+        score = ScoreResponse.model_validate(score_response.json())
+
+        assert score.id is not None
+        assert score.data is not None
+        assert len(score.data) == 1
+
+        vllm_outputs = [d.score for d in score.data]
+
+        text_pairs = [[text_1, text_2]]
+        hf_outputs = run_transformers(runner, model, text_pairs)
+
+        for i in range(len(vllm_outputs)):
+            assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
+
+    def test_score_max_model_len(self, server: RemoteOpenAIServer,
+                                 model: dict[str, Any]):
+
+        text_1 = "What is the capital of France?" * 20
+        text_2 = [
+            "The capital of Brazil is Brasilia.",
+            "The capital of France is Paris."
+        ]
+
+        score_response = requests.post(server.url_for("score"),
+                                       json={
+                                           "model": model["name"],
+                                           "text_1": text_1,
+                                           "text_2": text_2,
+                                       })
+        assert score_response.status_code == 400
+        # Assert just a small fragments of the response
+        assert "Please reduce the length of the input." in \
+            score_response.text
+
+        # Test truncation
+        score_response = requests.post(server.url_for("score"),
+                                       json={
+                                           "model": model["name"],
+                                           "text_1": text_1,
+                                           "text_2": text_2,
+                                           "truncate_prompt_tokens": 101
+                                       })
+        assert score_response.status_code == 400
+        assert "Please, select a smaller truncation size." in \
+            score_response.text
