44"""
55import asyncio
66import logging
7- from typing import Literal
7+ from typing import Any , Literal
88
99import numpy as np
1010
@@ -85,6 +85,23 @@ def __call__(
         return super().__call__(input, *args, **kwargs)
 
 
+def _split_batched_kwargs(
+    kwargs: dict[str, list[Any]]
+) -> tuple[dict[str, Any], dict[str, list[Any]]]:
+    constant_kwargs = {}
+    per_row_kwargs = {}
+
+    if kwargs:
+        for key, values in kwargs.items():
+            v = values[0]
+            if all(value == v for value in values):
+                constant_kwargs[key] = v
+            else:
+                per_row_kwargs[key] = values
+
+    return constant_kwargs, per_row_kwargs
+
+
 class OpenAIEmbedder(BaseEmbedder):
     """Pathway wrapper for OpenAI Embedding services.
 
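As an aside, a minimal sketch of how the new `_split_batched_kwargs` helper behaves on a batch of per-row keyword arguments; the dictionary below is an invented illustration, not part of the patch:

# Hypothetical batch of 3 rows.
batched = {
    "model": ["text-embedding-3-small"] * 3,  # same value for every row
    "user": ["u1", "u2", "u1"],               # differs between rows
}
constant, per_row = _split_batched_kwargs(batched)
# constant == {"model": "text-embedding-3-small"}
# per_row == {"user": ["u1", "u2", "u1"]}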
@@ -114,6 +131,8 @@ class OpenAIEmbedder(BaseEmbedder):
             Can be ``"start"``, ``"end"`` or ``None``. ``"start"`` will keep the first part of the text
             and remove the rest. ``"end"`` will keep the last part of the text.
             If `None`, no truncation will be applied to any of the documents, this may cause API exceptions.
+        batch_size: maximum size of a single batch to be sent to the embedder. Bigger
+            batches may reduce the time needed for embedding.
         encoding_format: The format to return the embeddings in. Can be either `float` or
             `base64 <https://pypi.org/project/pybase64/>`_.
         user: A unique identifier representing your end-user, which can help OpenAI to monitor
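A possible constructor call using the new `batch_size` parameter documented above; the model name, batch size, and table usage are illustrative, assuming the usual Pathway UDF call pattern:

import pathway as pw
from pathway.xpacks.llm.embedders import OpenAIEmbedder

# Send at most 64 texts per request instead of the default 128.
embedder = OpenAIEmbedder(model="text-embedding-3-small", batch_size=64)
# docs_with_vectors = docs.select(vector=embedder(pw.this.text))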
@@ -160,6 +179,7 @@ def __init__(
         cache_strategy: udfs.CacheStrategy | None = None,
         model: str | None = "text-embedding-3-small",
         truncation_keep_strategy: Literal["start", "end"] | None = "start",
+        batch_size: int = 128,
         **openai_kwargs,
     ):
         with optional_imports("xpack-llm"):
@@ -168,42 +188,82 @@ def __init__(
             _monkeypatch_openai_async()
         executor = udfs.async_executor(capacity=capacity, retry_strategy=retry_strategy)
         super().__init__(
-            executor=executor,
-            cache_strategy=cache_strategy,
+            executor=executor, cache_strategy=cache_strategy, max_batch_size=batch_size
         )
         self.truncation_keep_strategy = truncation_keep_strategy
         self.kwargs = dict(openai_kwargs)
-        api_key = self.kwargs.pop("api_key", None)
-        self.client = openai.AsyncOpenAI(api_key=api_key, max_retries=0)
+        self.api_key = self.kwargs.pop("api_key", None)
+        self.client: openai.AsyncOpenAI | None = None
+
+        # Initialization of OpenAI for the purpose of checking if api_key was provided
+        # Actual initialization of the client is delayed to __wrapped__ to avoid issues
+        # with the event loop.
+        _ = openai.AsyncOpenAI(api_key=self.api_key, max_retries=0)
         if model is not None:
             self.kwargs["model"] = model
 
-    async def __wrapped__(self, input, **kwargs) -> np.ndarray:
+    async def __wrapped__(self, inputs: list[str], **kwargs) -> list[np.ndarray]:
         """Embed the documents
 
         Args:
-            input: mandatory, the string to embed.
+            inputs: mandatory, the strings to embed.
             **kwargs: optional parameters, if unset defaults from the constructor
                 will be taken.
-        """
-        input = input or "."
+        """
+        import openai
+
+        if self.client is None:
+            self.client = openai.AsyncOpenAI(api_key=self.api_key, max_retries=0)
 
-        kwargs = {**self.kwargs, **kwargs}
         kwargs = _extract_value_inside_dict(kwargs)
 
-        if kwargs.get("model") is None:
+        if kwargs.get("model") is None and self.kwargs.get("model") is None:
             raise ValueError(
                 "`model` parameter is missing in `OpenAIEmbedder`. "
                 "Please provide the model name either in the constructor or in the function call."
             )
 
+        constant_kwargs, per_row_kwargs = _split_batched_kwargs(kwargs)
+        constant_kwargs = {**self.kwargs, **constant_kwargs}
+
         if self.truncation_keep_strategy:
-            input = self.truncate_context(
-                kwargs["model"], input, self.truncation_keep_strategy
-            )
+            if "model" in per_row_kwargs:
+                inputs = [
+                    self.truncate_context(model, input, self.truncation_keep_strategy)
+                    for (model, input) in zip(per_row_kwargs["model"], inputs)
+                ]
+            else:
+                inputs = [
+                    self.truncate_context(
+                        constant_kwargs["model"], input, self.truncation_keep_strategy
+                    )
+                    for input in inputs
+                ]
+
+        # if kwargs are not the same for every input we cannot batch them
+        if per_row_kwargs:
+
+            async def embed_single(input, kwargs) -> np.ndarray:
+                kwargs = {**constant_kwargs, **kwargs}
+                ret = await self.client.embeddings.create(input=[input], **kwargs)  # type: ignore
+                return np.array(ret.data[0].embedding)
+
+            list_of_per_row_kwargs = [
+                dict(zip(per_row_kwargs, values))
+                for values in zip(*per_row_kwargs.values())
+            ]
+            async with asyncio.TaskGroup() as tg:
+                tasks = [
+                    tg.create_task(embed_single(input, kwargs))
+                    for input, kwargs in zip(inputs, list_of_per_row_kwargs)
+                ]
 
-        ret = await self.client.embeddings.create(input=[input], **kwargs)
-        return np.array(ret.data[0].embedding)
+            result_list = [task.result() for task in tasks]
+            return result_list
+
+        else:
+            ret = await self.client.embeddings.create(input=inputs, **constant_kwargs)
+            return [np.array(datum.embedding) for datum in ret.data]
 
     @staticmethod
     def truncate_context(
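The per-row branch in the hunk above first transposes the dict of per-row value lists into one kwargs dict per input before fanning the requests out through an asyncio.TaskGroup. A self-contained sketch of that transposition, with invented values:

per_row_kwargs = {"model": ["m1", "m2"], "user": ["u1", "u2"]}
list_of_per_row_kwargs = [
    dict(zip(per_row_kwargs, values))
    for values in zip(*per_row_kwargs.values())
]
# -> [{"model": "m1", "user": "u1"}, {"model": "m2", "user": "u2"}]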
@@ -250,6 +310,18 @@ def truncate_context(
 
         return tokenizer.decode(tokens)
 
+    def get_embedding_dimension(self, **kwargs):
+        """Computes the number of the embedder's dimensions by asking it to embed ``"."``.
+
+        Args:
+            **kwargs: parameters of the embedder, if unset defaults from the constructor
+                will be taken.
+        """
+        kwargs_as_list = {k: [v] for k, v in kwargs.items()}
+        n_dimensions = len(_coerce_sync(self.__wrapped__)(["."], **kwargs_as_list)[0])
+        self.client = None
+        return n_dimensions
+
 
 class LiteLLMEmbedder(BaseEmbedder):
     """Pathway wrapper for `litellm.embedding`.
@@ -406,16 +478,7 @@ def __wrapped__(self, input: list[str], **kwargs) -> list[np.ndarray]:
406478 """ # noqa: E501
407479
408480 kwargs = _extract_value_inside_dict (kwargs )
409- constant_kwargs = {}
410- per_row_kwargs = {}
411-
412- if kwargs :
413- for key , values in kwargs .items ():
414- v = values [0 ]
415- if all (value == v for value in values ):
416- constant_kwargs [key ] = v
417- else :
418- per_row_kwargs [key ] = values
481+ constant_kwargs , per_row_kwargs = _split_batched_kwargs (kwargs )
419482
420483 # if kwargs are not the same for every input we cannot batch them
421484 if per_row_kwargs :