Skip to content

Commit 5b54495

Browse files
committed
Puts embedding score code in score_utils to avoid duplicated code
Signed-off-by: Gabriel Marinho <[email protected]>
1 parent 6218cc3 commit 5b54495

File tree

3 files changed

+56
-42
lines changed

3 files changed

+56
-42
lines changed

vllm/entrypoints/llm.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Tuple, Type, Union, cast, overload)
88

99
import cloudpickle
10-
import torch
1110
import torch.nn as nn
1211
from tqdm import tqdm
1312
from typing_extensions import TypeVar, deprecated
@@ -25,6 +24,7 @@
2524
apply_mistral_chat_template,
2625
parse_chat_messages,
2726
resolve_chat_template_content_format)
27+
from vllm.entrypoints.score_utils import _cosine_similarity
2828
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
2929
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
3030
from vllm.logger import init_logger
@@ -1010,40 +1010,25 @@ def _embedding_score(
10101010
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
10111011
) -> List[ScoringRequestOutput]:
10121012

1013-
encoded_output = self.encode(
1013+
encoded_output: List[PoolingRequestOutput] = self.encode(
10141014
text_1 + text_2,
10151015
use_tqdm=use_tqdm,
10161016
lora_request=lora_request,
10171017
prompt_adapter_request=prompt_adapter_request)
1018-
encoded_output_1 = encoded_output[0:len(text_1)]
1019-
encoded_output_2 = encoded_output[len(text_1):]
1018+
1019+
encoded_output_1: List[PoolingRequestOutput] = encoded_output[
1020+
0:len(text_1)]
1021+
encoded_output_2: List[PoolingRequestOutput] = encoded_output[
1022+
len(text_1):]
10201023

10211024
if len(encoded_output_1) == 1:
10221025
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
10231026

1024-
output_pairs = [(t1, t2)
1025-
for t1, t2 in zip(encoded_output_1, encoded_output_2)]
1026-
1027-
scores = []
1028-
scorer = torch.nn.CosineSimilarity(0)
1027+
scores: List[PoolingRequestOutput] = []
10291028

1030-
for embed_1, embed_2 in output_pairs:
1031-
pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)
1032-
1033-
if (pad_token_id := getattr(tokenizer, "pad_token_id",
1034-
None)) is not None:
1035-
tokens = embed_1.prompt_token_ids + [
1036-
pad_token_id
1037-
] + embed_2.prompt_token_ids
1038-
else:
1039-
tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids
1040-
1041-
scores.append(
1042-
PoolingRequestOutput(
1043-
request_id=f"{embed_1.request_id}_{embed_2.request_id}",
1044-
outputs=pair_score,
1045-
prompt_token_ids=tokens,
1046-
finished=True))
1029+
scores = _cosine_similarity(tokenizer=tokenizer,
1030+
embed_1=encoded_output_1,
1031+
embed_2=encoded_output_2)
10471032

10481033
items = self.engine_class.validate_outputs(scores,
10491034
PoolingRequestOutput)

vllm/entrypoints/openai/serving_score.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import time
44
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Union
55

6-
import torch
76
from fastapi import Request
87

98
from vllm.config import ModelConfig
@@ -16,6 +15,7 @@
1615
ScoreResponseData, UsageInfo)
1716
from vllm.entrypoints.openai.serving_engine import OpenAIServing
1817
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
18+
from vllm.entrypoints.score_utils import _cosine_similarity
1919
from vllm.inputs.data import TokensPrompt
2020
from vllm.logger import init_logger
2121
from vllm.lora.request import LoRARequest
@@ -121,26 +121,18 @@ async def _embedding_score(
121121
if len(emb_text_1) == 1:
122122
emb_text_1 = emb_text_1 * len(emb_text_2)
123123

124-
scorer = torch.nn.CosineSimilarity(0)
124+
embeddings_1: List[PoolingRequestOutput] = []
125+
embeddings_2: List[PoolingRequestOutput] = []
125126

126127
for emb_1, emb_2 in zip(emb_text_1, emb_text_2):
127128
assert emb_1 is not None
128129
assert emb_2 is not None
129-
pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)
130+
embeddings_1.append(emb_1)
131+
embeddings_2.append(emb_2)
130132

131-
padding = []
132-
if (pad_token_id := getattr(tokenizer, "pad_token_id",
133-
None)) is not None:
134-
padding = [pad_token_id]
135-
136-
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
137-
138-
final_res_batch.append(
139-
PoolingRequestOutput(
140-
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
141-
outputs=pair_score,
142-
prompt_token_ids=tokens,
143-
finished=True))
133+
final_res_batch = _cosine_similarity(tokenizer=tokenizer,
134+
embed_1=embeddings_1,
135+
embed_2=embeddings_2)
144136

145137
return final_res_batch
146138

vllm/entrypoints/score_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from typing import List, Union
3+
4+
from torch.nn import CosineSimilarity
5+
6+
from vllm.outputs import PoolingRequestOutput
7+
from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
8+
PreTrainedTokenizerFast)
9+
10+
11+
def _cosine_similarity(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    embed_1: List[PoolingRequestOutput],
    embed_2: List[PoolingRequestOutput],
) -> List[PoolingRequestOutput]:
    """Score each pair of embeddings with cosine similarity.

    Pairs element ``i`` of ``embed_1`` with element ``i`` of ``embed_2``
    (extra elements of the longer list are ignored, per ``zip`` semantics)
    and wraps each similarity in a new ``PoolingRequestOutput``.

    Args:
        tokenizer: Tokenizer whose ``pad_token_id`` (if any) separates the
            two prompts' token ids in the combined ``prompt_token_ids``.
        embed_1: Pooling outputs holding the first embedding of each pair.
        embed_2: Pooling outputs holding the second embedding of each pair.

    Returns:
        One ``PoolingRequestOutput`` per pair whose ``outputs`` is the
        cosine similarity of the two embedding tensors.
    """
    # dim=0: the pooled embeddings are 1-D tensors, so similarity is
    # computed over the single (feature) dimension.
    scorer = CosineSimilarity(0)
    # Fix: the original annotated this as Union[List[PoolingRequestOutput]];
    # a single-argument Union is a no-op and masks the intended type.
    scores: List[PoolingRequestOutput] = []

    for emb_1, emb_2 in zip(embed_1, embed_2):
        pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)

        # Separate the two prompts with the pad token when the tokenizer
        # defines one; otherwise concatenate them directly.
        padding: List[int] = []
        if (pad_token_id := getattr(tokenizer, "pad_token_id",
                                    None)) is not None:
            padding = [pad_token_id]

        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=pair_score,
                prompt_token_ids=tokens,
                finished=True))

    return scores

0 commit comments

Comments
 (0)