Commit 68fc014

feat(vllm): add support for embeddings (#3440)
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 56db715 · commit 68fc014

File tree

2 files changed: +43 -0 lines changed

backend/python/vllm/backend.py

Lines changed: 20 additions & 0 deletions
@@ -135,6 +135,26 @@ async def Predict(self, request, context):
         res = await gen.__anext__()
         return res
 
+    def Embedding(self, request, context):
+        """
+        A gRPC method that calculates embeddings for a given sentence.
+
+        Args:
+            request: An EmbeddingRequest object that contains the request parameters.
+            context: A grpc.ServicerContext object that provides information about the RPC.
+
+        Returns:
+            An EmbeddingResult object that contains the calculated embeddings.
+        """
+        print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
+        outputs = self.model.encode(request.Embeddings)
+        # Check if we have one result at least
+        if len(outputs) == 0:
+            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+            context.set_details("No embeddings were calculated.")
+            return backend_pb2.EmbeddingResult()
+        return backend_pb2.EmbeddingResult(embeddings=outputs[0].outputs.embedding)
+
     async def PredictStream(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters, and streams the results.
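Not part of the commit: a minimal client sketch of how the new Embedding RPC could be exercised end to end. It assumes the generated backend_pb2 / backend_pb2_grpc modules are importable and that a vLLM backend is already listening on localhost:50051; the helper name get_embedding and the model id are illustrative and simply mirror the test further down.

import grpc

import backend_pb2
import backend_pb2_grpc


def get_embedding(sentence, address="localhost:50051", model="intfloat/e5-mistral-7b-instruct"):
    """Illustrative helper: load a model, then request embeddings for one sentence."""
    with grpc.insecure_channel(address) as channel:
        stub = backend_pb2_grpc.BackendStub(channel)
        # Load the model first; the backend reports success in the result message.
        load_result = stub.LoadModel(backend_pb2.ModelOptions(Model=model))
        if not load_result.success:
            raise RuntimeError("LoadModel failed")
        # The sentence travels in the Embeddings field of PredictOptions;
        # the reply is an EmbeddingResult carrying a flat list of floats.
        reply = stub.Embedding(backend_pb2.PredictOptions(Embeddings=sentence))
        return list(reply.embeddings)


if __name__ == "__main__":
    vector = get_embedding("This is a test sentence.")
    print(len(vector), vector[:5])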

backend/python/vllm/test.py

Lines changed: 23 additions & 0 deletions
@@ -72,5 +72,28 @@ def test_text(self):
         except Exception as err:
             print(err)
             self.fail("text service failed")
+        finally:
+            self.tearDown()
+
+    def test_embedding(self):
+        """
+        This method tests if the embeddings are generated successfully
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct"))
+                self.assertTrue(response.success)
+                embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
+                embedding_response = stub.Embedding(embedding_request)
+                self.assertIsNotNone(embedding_response.embeddings)
+                # assert that is a list of floats
+                self.assertIsInstance(embedding_response.embeddings, list)
+                # assert that the list is not empty
+                self.assertTrue(len(embedding_response.embeddings) > 0)
+        except Exception as err:
+            print(err)
+            self.fail("Embedding service failed")
         finally:
             self.tearDown()
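Also not part of the commit: a short sketch of one way the returned vectors could be consumed once the RPC is in place, for example to compare two sentences with cosine similarity. The helper below is hypothetical and only assumes what the test asserts, namely that stub.Embedding returns an EmbeddingResult whose embeddings field is a flat list of floats.

import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length float vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


# Hypothetical usage with a gRPC stub like the one in the client sketch above:
# first = list(stub.Embedding(backend_pb2.PredictOptions(Embeddings="A cat sat on the mat.")).embeddings)
# second = list(stub.Embedding(backend_pb2.PredictOptions(Embeddings="A dog slept on the rug.")).embeddings)
# print(cosine_similarity(first, second))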
