Commit 7a07024

Merge pull request #2 from predictionguard/fix-batch-embeddings

fix batch embeddings above 32

2 parents: 699abc0 + 8f0e17b

1 file changed: 27 additions and 17 deletions

langchain_predictionguard/PredictionGuardEmbeddings.py
@@ -70,23 +70,33 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
             Embeddings for the texts.
         """
 
-        inputs = []
-        for text in texts:
-            input = {"text": text}
-            inputs.append(input)
-
-        response = self.client.embeddings.create(model=self.model, input=inputs)
-
-        res = []
-        indx = 0
-        for re in response["data"]:
-            if re["index"] == indx:
-                res.append(re["embedding"])
-                indx += 1
-            else:
-                continue
-
-        return res
+        max_batch_size = 30
+        embeddings = []
+        if len(texts) < max_batch_size:
+            response = self.client.embeddings.create(
+                model=self.model,
+                input=texts,
+                truncate=True
+            )
+            for idx in range(0, len(response["data"])):
+                for emb in response["data"]:
+                    if emb['index'] == idx:
+                        embeddings.append(emb['embedding'])
+        else:
+            smaller_batches = [texts[i:i+max_batch_size] for i in range(0, len(texts), max_batch_size)]
+            embeddings = []
+            for smaller_batch in smaller_batches:
+                response = self.client.embeddings.create(
+                    model=self.model,
+                    input=smaller_batch,
+                    truncate=True
+                )
+                for idx in range(0, len(response["data"])):
+                    for emb in response["data"]:
+                        if emb['index'] == idx:
+                            embeddings.append(emb['embedding'])
+
+        return embeddings
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to Prediction Guard's embedding endpoint for embedding query text.
