Skip to content

Commit 7fa7c81

Browse files
szymondudycz and Manul from Pathway
authored and committed
Batch OpenAI embedder (#9381)
GitOrigin-RevId: 4dd920c04055365d0992802b600e48c2788535b1
1 parent 6c97d03 commit 7fa7c81

File tree

2 files changed

+173
-33
lines changed

2 files changed

+173
-33
lines changed

integration_tests/xpack/test_embedders.py

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
"model", [None, "text-embedding-ada-002", "text-embedding-3-small"]
1616
)
1717
@pytest.mark.parametrize("strategy", ["start", "end"])
18-
def test_openai_embedder(text: str, model: str, strategy: str):
18+
def test_openai_embedder(text: str, model: str | None, strategy: str):
19+
table = pw.debug.table_from_rows(
20+
schema=pw.schema_from_types(text=str), rows=[(text,)]
21+
)
1922
if model is None:
2023
embedder = embedders.OpenAIEmbedder(
21-
truncation_keep_strategy=strategy,
24+
truncation_keep_strategy=strategy, # type: ignore
2225
retry_strategy=pw.udfs.ExponentialBackoffRetryStrategy(),
2326
)
2427
else:
@@ -28,28 +31,102 @@ def test_openai_embedder(text: str, model: str, strategy: str):
2831
retry_strategy=pw.udfs.ExponentialBackoffRetryStrategy(),
2932
)
3033

31-
sync_embedder = _coerce_sync(embedder.func)
34+
table = table.select(embedding=embedder(pw.this.text))
3235

33-
embedding = sync_embedder(text)
36+
result = pw.debug.table_to_pandas(table).to_dict("records")
3437

35-
assert len(embedding) > 1500
38+
assert len(result) == 1
39+
assert isinstance(result[0]["embedding"][0], float)
40+
assert len(result[0]["embedding"]) > 1500
3641

3742

3843
@pytest.mark.parametrize("model", ["text-embedding-ada-002", "text-embedding-3-small"])
3944
def test_openai_embedder_fails_no_truncation(model: str):
4045
truncation_keep_strategy = None
4146
embedder = embedders.OpenAIEmbedder(
42-
model=model, truncation_keep_strategy=truncation_keep_strategy
47+
model=model,
48+
truncation_keep_strategy=truncation_keep_strategy,
49+
retry_strategy=pw.udfs.ExponentialBackoffRetryStrategy(),
4350
)
4451

4552
sync_embedder = _coerce_sync(embedder.func)
4653

4754
with pytest.raises(Exception) as exc:
48-
sync_embedder(LONG_TEXT)
55+
sync_embedder([LONG_TEXT])
4956

5057
assert "maximum context length" in str(exc)
5158

5259

60+
def test_openai_embedder_with_common_parameter():
61+
table = pw.debug.table_from_rows(
62+
schema=pw.schema_from_types(text=str), rows=[("aaa",), ("bbb",)]
63+
)
64+
65+
embedder = embedders.OpenAIEmbedder(
66+
model="text-embedding-3-small",
67+
retry_strategy=pw.udfs.ExponentialBackoffRetryStrategy(),
68+
)
69+
70+
table = table.select(embedding=embedder(pw.this.text, dimensions=700))
71+
72+
result = pw.debug.table_to_pandas(table).to_dict("records")
73+
74+
assert len(result) == 2
75+
assert isinstance(result[0]["embedding"][0], float)
76+
assert len(result[0]["embedding"]) == 700
77+
assert isinstance(result[1]["embedding"][0], float)
78+
assert len(result[1]["embedding"]) == 700
79+
80+
81+
def test_openai_embedder_with_different_parameter():
82+
table = pw.debug.table_from_rows(
83+
schema=pw.schema_from_types(text=str, dimensions=int),
84+
rows=[("aaa", 300), ("bbb", 800)],
85+
)
86+
87+
embedder = embedders.OpenAIEmbedder(
88+
model="text-embedding-3-small",
89+
retry_strategy=pw.udfs.ExponentialBackoffRetryStrategy(),
90+
)
91+
92+
table = table.select(
93+
text=pw.this.text,
94+
embedding=embedder(pw.this.text, dimensions=pw.this.dimensions),
95+
)
96+
97+
result = pw.debug.table_to_pandas(table).to_dict("records")
98+
99+
assert len(result) == 2
100+
assert isinstance(result[0]["embedding"][0], float)
101+
assert isinstance(result[1]["embedding"][0], float)
102+
if result[0]["text"] == "aaa":
103+
assert len(result[0]["embedding"]) == 300
104+
else:
105+
assert len(result[1]["embedding"]) == 300
106+
if result[0]["text"] == "bbb":
107+
assert len(result[0]["embedding"]) == 800
108+
else:
109+
assert len(result[1]["embedding"]) == 800
110+
111+
112+
def test_openai_embedder_input_as_kwarg():
113+
table = pw.debug.table_from_rows(
114+
schema=pw.schema_from_types(text=str), rows=[("foo",)]
115+
)
116+
embedder = embedders.OpenAIEmbedder(
117+
model="text-embedding-3-small",
118+
retry_strategy=pw.udfs.ExponentialBackoffRetryStrategy(),
119+
)
120+
121+
table = table.select(embedding=embedder(input=pw.this.text))
122+
123+
result = pw.debug.table_to_pandas(table).to_dict("records")
124+
125+
assert len(result) == 1
126+
assert isinstance(result[0]["embedding"][0], float)
127+
assert len(result[0]["embedding"]) > 1500
128+
129+
53130
def test_sentence_transformer_embedder():
54131
table = pw.debug.table_from_rows(
55132
schema=pw.schema_from_types(text=str), rows=[("aaa",), ("bbb",)]

python/pathway/xpacks/llm/embedders.py

Lines changed: 89 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55
import asyncio
66
import logging
7-
from typing import Literal
7+
from typing import Any, Literal
88

99
import numpy as np
1010

@@ -85,6 +85,23 @@ def __call__(
8585
return super().__call__(input, *args, **kwargs)
8686

8787

88+
def _split_batched_kwargs(
89+
kwargs: dict[str, list[Any]]
90+
) -> tuple[dict[str, Any], dict[str, list[Any]]]:
91+
constant_kwargs = {}
92+
per_row_kwargs = {}
93+
94+
if kwargs:
95+
for key, values in kwargs.items():
96+
v = values[0]
97+
if all(value == v for value in values):
98+
constant_kwargs[key] = v
99+
else:
100+
per_row_kwargs[key] = values
101+
102+
return constant_kwargs, per_row_kwargs
103+
104+
88105
class OpenAIEmbedder(BaseEmbedder):
89106
"""Pathway wrapper for OpenAI Embedding services.
90107
@@ -114,6 +131,8 @@ class OpenAIEmbedder(BaseEmbedder):
114131
Can be ``"start"``, ``"end"`` or ``None``. ``"start"`` will keep the first part of the text
115132
and remove the rest. ``"end"`` will keep the last part of the text.
116133
If `None`, no truncation will be applied to any of the documents, this may cause API exceptions.
134+
batch_size: maximum size of a single batch to be sent to the embedder. Bigger
135+
batches may reduce the time needed for embedding.
117136
encoding_format: The format to return the embeddings in. Can be either `float` or
118137
`base64 <https://pypi.org/project/pybase64/>`_.
119138
user: A unique identifier representing your end-user, which can help OpenAI to monitor
@@ -160,6 +179,7 @@ def __init__(
160179
cache_strategy: udfs.CacheStrategy | None = None,
161180
model: str | None = "text-embedding-3-small",
162181
truncation_keep_strategy: Literal["start", "end"] | None = "start",
182+
batch_size: int = 128,
163183
**openai_kwargs,
164184
):
165185
with optional_imports("xpack-llm"):
@@ -168,42 +188,82 @@ def __init__(
168188
_monkeypatch_openai_async()
169189
executor = udfs.async_executor(capacity=capacity, retry_strategy=retry_strategy)
170190
super().__init__(
171-
executor=executor,
172-
cache_strategy=cache_strategy,
191+
executor=executor, cache_strategy=cache_strategy, max_batch_size=batch_size
173192
)
174193
self.truncation_keep_strategy = truncation_keep_strategy
175194
self.kwargs = dict(openai_kwargs)
176-
api_key = self.kwargs.pop("api_key", None)
177-
self.client = openai.AsyncOpenAI(api_key=api_key, max_retries=0)
195+
self.api_key = self.kwargs.pop("api_key", None)
196+
self.client: openai.AsyncOpenAI | None = None
197+
198+
# Initialization of OpenAI for the purpose of checking if api_key was provided
199+
# Actual initialization of the client is delayed to __wrapped__ to avoid issues
200+
# with the event loop.
201+
_ = openai.AsyncOpenAI(api_key=self.api_key, max_retries=0)
178202
if model is not None:
179203
self.kwargs["model"] = model
180204

181-
async def __wrapped__(self, input, **kwargs) -> np.ndarray:
205+
async def __wrapped__(self, inputs: list[str], **kwargs) -> list[np.ndarray]:
182206
"""Embed the documents
183207
184208
Args:
185-
input: mandatory, the string to embed.
209+
inputs: mandatory, the strings to embed.
186210
**kwargs: optional parameters, if unset defaults from the constructor
187211
will be taken.
188-
"""
189-
input = input or "."
212+
"""
213+
import openai
214+
215+
if self.client is None:
216+
self.client = openai.AsyncOpenAI(api_key=self.api_key, max_retries=0)
190217

191-
kwargs = {**self.kwargs, **kwargs}
192218
kwargs = _extract_value_inside_dict(kwargs)
193219

194-
if kwargs.get("model") is None:
220+
if kwargs.get("model") is None and self.kwargs.get("model") is None:
195221
raise ValueError(
196222
"`model` parameter is missing in `OpenAIEmbedder`. "
197223
"Please provide the model name either in the constructor or in the function call."
198224
)
199225

226+
constant_kwargs, per_row_kwargs = _split_batched_kwargs(kwargs)
227+
constant_kwargs = {**self.kwargs, **constant_kwargs}
228+
200229
if self.truncation_keep_strategy:
201-
input = self.truncate_context(
202-
kwargs["model"], input, self.truncation_keep_strategy
203-
)
230+
if "model" in per_row_kwargs:
231+
inputs = [
232+
self.truncate_context(model, input, self.truncation_keep_strategy)
233+
for (model, input) in zip(per_row_kwargs["model"], inputs)
234+
]
235+
else:
236+
inputs = [
237+
self.truncate_context(
238+
constant_kwargs["model"], input, self.truncation_keep_strategy
239+
)
240+
for input in inputs
241+
]
242+
243+
# if kwargs are not the same for every input we cannot batch them
244+
if per_row_kwargs:
245+
246+
async def embed_single(input, kwargs) -> np.ndarray:
247+
kwargs = {**constant_kwargs, **kwargs}
248+
ret = await self.client.embeddings.create(input=[input], **kwargs) # type: ignore
249+
return np.array(ret.data[0].embedding)
250+
251+
list_of_per_row_kwargs = [
252+
dict(zip(per_row_kwargs, values))
253+
for values in zip(*per_row_kwargs.values())
254+
]
255+
async with asyncio.TaskGroup() as tg:
256+
tasks = [
257+
tg.create_task(embed_single(input, kwargs))
258+
for input, kwargs in zip(inputs, list_of_per_row_kwargs)
259+
]
204260

205-
ret = await self.client.embeddings.create(input=[input], **kwargs)
206-
return np.array(ret.data[0].embedding)
261+
result_list = [task.result() for task in tasks]
262+
return result_list
263+
264+
else:
265+
ret = await self.client.embeddings.create(input=inputs, **constant_kwargs)
266+
return [np.array(datum.embedding) for datum in ret.data]
207267

208268
@staticmethod
209269
def truncate_context(
@@ -250,6 +310,18 @@ def truncate_context(
250310

251311
return tokenizer.decode(tokens)
252312

313+
def get_embedding_dimension(self, **kwargs):
314+
"""Computes number of embedder's dimensions by asking the embedder to embed ``"."``.
315+
316+
Args:
317+
**kwargs: parameters of the embedder, if unset defaults from the constructor
318+
will be taken.
319+
"""
320+
kwargs_as_list = {k: [v] for k, v in kwargs.items()}
321+
n_dimensions = len(_coerce_sync(self.__wrapped__)(["."], **kwargs_as_list)[0])
322+
self.client = None
323+
return n_dimensions
324+
253325

254326
class LiteLLMEmbedder(BaseEmbedder):
255327
"""Pathway wrapper for `litellm.embedding`.
@@ -406,16 +478,7 @@ def __wrapped__(self, input: list[str], **kwargs) -> list[np.ndarray]:
406478
""" # noqa: E501
407479

408480
kwargs = _extract_value_inside_dict(kwargs)
409-
constant_kwargs = {}
410-
per_row_kwargs = {}
411-
412-
if kwargs:
413-
for key, values in kwargs.items():
414-
v = values[0]
415-
if all(value == v for value in values):
416-
constant_kwargs[key] = v
417-
else:
418-
per_row_kwargs[key] = values
481+
constant_kwargs, per_row_kwargs = _split_batched_kwargs(kwargs)
419482

420483
# if kwargs are not the same for every input we cannot batch them
421484
if per_row_kwargs:

0 commit comments

Comments (0)