Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/fast-plants-wink.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@langchain/openai": patch
"@langchain/core": patch
---

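Support the base64 embeddings format: add an `encodingFormat` option (`"float"` | `"base64"`) to `OpenAIEmbeddings`, and make `Embeddings` / `EmbeddingsInterface` generic over the embedding output type (defaulting to `number[]`).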
14 changes: 8 additions & 6 deletions langchain-core/src/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,32 @@ import { AsyncCaller, AsyncCallerParams } from "./utils/async_caller.js";
*/
export type EmbeddingsParams = AsyncCallerParams;

export interface EmbeddingsInterface {
export interface EmbeddingsInterface<TOutput = number[]> {
/**
* An abstract method that takes an array of documents as input and
* returns a promise that resolves to an array of vectors for each
* document.
* @param documents An array of documents to be embedded.
* @returns A promise that resolves to an array containing one embedding per input document.
*/
embedDocuments(documents: string[]): Promise<number[][]>;
embedDocuments(documents: string[]): Promise<TOutput[]>;

/**
* An abstract method that takes a single document as input and returns a
* promise that resolves to a vector for the query document.
* @param document A single document to be embedded.
* @returns A promise that resolves to a vector for the query document.
*/
embedQuery(document: string): Promise<number[]>;
embedQuery(document: string): Promise<TOutput>;
}

/**
* An abstract class that provides methods for embedding documents and
* queries using LangChain.
*/
export abstract class Embeddings implements EmbeddingsInterface {
export abstract class Embeddings<TOutput = number[]>
implements EmbeddingsInterface<TOutput>
{
/**
* The async caller should be used by subclasses to make any async calls,
* which will thus benefit from the concurrency and retry logic.
Expand All @@ -47,13 +49,13 @@ export abstract class Embeddings implements EmbeddingsInterface {
* @param documents An array of documents to be embedded.
* @returns A promise that resolves to an array containing one embedding per input document.
*/
abstract embedDocuments(documents: string[]): Promise<number[][]>;
abstract embedDocuments(documents: string[]): Promise<TOutput[]>;

/**
* An abstract method that takes a single document as input and returns a
* promise that resolves to a vector for the query document.
* @param document A single document to be embedded.
* @returns A promise that resolves to a vector for the query document.
*/
abstract embedQuery(document: string): Promise<number[]>;
abstract embedQuery(document: string): Promise<TOutput>;
}
32 changes: 23 additions & 9 deletions libs/langchain-openai/src/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { type ClientOptions, OpenAI as OpenAIClient } from "openai";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import { OpenAICoreRequestOptions } from "./types.js";
import { getEndpoint, OpenAIEndpointConfig } from "./utils/azure.js";
import { wrapOpenAIClientError } from "./utils/openai.js";

Expand Down Expand Up @@ -50,6 +49,11 @@ export interface OpenAIEmbeddingsParams extends EmbeddingsParams {
* See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
*/
stripNewLines?: boolean;

/**
* The format to return the embeddings in. Can be either 'float' or 'base64'.
*/
encodingFormat?: "float" | "base64";
}

/**
Expand All @@ -68,8 +72,8 @@ export interface OpenAIEmbeddingsParams extends EmbeddingsParams {
*
* ```
*/
export class OpenAIEmbeddings
extends Embeddings
export class OpenAIEmbeddings<TOutput = number[]>
extends Embeddings<TOutput>
implements Partial<OpenAIEmbeddingsParams>
{
model = "text-embedding-ada-002";
Expand All @@ -92,6 +96,8 @@ export class OpenAIEmbeddings

organization?: string;

encodingFormat?: "float" | "base64";

protected client: OpenAIClient;

protected clientConfig: ClientOptions;
Expand Down Expand Up @@ -130,6 +136,7 @@ export class OpenAIEmbeddings
fieldsWithDefaults?.stripNewLines ?? this.stripNewLines;
this.timeout = fieldsWithDefaults?.timeout;
this.dimensions = fieldsWithDefaults?.dimensions;
this.encodingFormat = fieldsWithDefaults?.encodingFormat;

this.clientConfig = {
apiKey,
Expand All @@ -146,7 +153,7 @@ export class OpenAIEmbeddings
* @param texts Array of documents to generate embeddings for.
* @returns Promise that resolves to an array with one embedding per document; each embedding's shape follows the configured encoding format.
*/
async embedDocuments(texts: string[]): Promise<number[][]> {
async embedDocuments(texts: string[]): Promise<TOutput[]> {
const batches = chunkArray(
this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
this.batchSize
Expand All @@ -160,16 +167,19 @@ export class OpenAIEmbeddings
if (this.dimensions) {
params.dimensions = this.dimensions;
}
if (this.encodingFormat) {
params.encoding_format = this.encodingFormat;
}
return this.embeddingWithRetry(params);
});
const batchResponses = await Promise.all(batchRequests);

const embeddings: number[][] = [];
const embeddings: TOutput[] = [];
for (let i = 0; i < batchResponses.length; i += 1) {
const batch = batches[i];
const { data: batchResponse } = batchResponses[i];
for (let j = 0; j < batch.length; j += 1) {
embeddings.push(batchResponse[j].embedding);
embeddings.push(batchResponse[j].embedding as TOutput);
}
}
return embeddings;
Expand All @@ -181,16 +191,19 @@ export class OpenAIEmbeddings
* @param text Document to generate an embedding for.
* @returns Promise that resolves to an embedding for the document.
*/
async embedQuery(text: string): Promise<number[]> {
async embedQuery(text: string): Promise<TOutput> {
const params: OpenAIClient.EmbeddingCreateParams = {
model: this.model,
input: this.stripNewLines ? text.replace(/\n/g, " ") : text,
};
if (this.dimensions) {
params.dimensions = this.dimensions;
}
if (this.encodingFormat) {
params.encoding_format = this.encodingFormat;
}
const { data } = await this.embeddingWithRetry(params);
return data[0].embedding;
return data[0].embedding as TOutput;
}

/**
Expand Down Expand Up @@ -223,7 +236,8 @@ export class OpenAIEmbeddings

this.client = new OpenAIClient(params);
}
const requestOptions: OpenAICoreRequestOptions = {};
const requestOptions = {};

return this.caller.call(async () => {
try {
const res = await this.client.embeddings.create(
Expand Down
34 changes: 34 additions & 0 deletions libs/langchain-openai/src/tests/embeddings.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,37 @@ test("Test OpenAIEmbeddings.embedDocuments with v3 and dimensions", async () =>
expect(res[0].length).toBe(127);
expect(res[1].length).toBe(127);
});

// Integration check: requesting the "float" encoding format explicitly
// yields a numeric query vector of the model's default size.
test("Test OpenAIEmbeddings.embedQuery with encodingFormat", async () => {
  const model = new OpenAIEmbeddings({
    encodingFormat: "float",
    modelName: "text-embedding-3-small",
  });

  const vector = await model.embedQuery("Hello world");
  const [firstComponent] = vector;

  expect(typeof firstComponent).toBe("number");
  // 1536 is the default dimension for text-embedding-3-small.
  expect(vector.length).toBe(1536);
});

// Integration check: with the "float" encoding format, every document is
// embedded as a numeric vector of the model's default size.
test("Test OpenAIEmbeddings.embedDocuments with encodingFormat", async () => {
  const model = new OpenAIEmbeddings({
    encodingFormat: "float",
    modelName: "text-embedding-3-small",
  });
  const inputs = ["Hello world", "Bye bye"];

  const vectors = await model.embedDocuments(inputs);

  // One embedding is expected per input document.
  expect(vectors).toHaveLength(2);
  const [first, second] = vectors;
  expect(typeof first[0]).toBe("number");
  expect(typeof second[0]).toBe("number");
  // 1536 is the default dimension for text-embedding-3-small.
  expect(first.length).toBe(1536);
  expect(second.length).toBe(1536);
});

// Integration check: the `dimensions` option is still honoured when an
// encoding format is supplied alongside it.
test("Test OpenAIEmbeddings with encodingFormat and custom dimensions", async () => {
  const model = new OpenAIEmbeddings({
    dimensions: 256,
    encodingFormat: "float",
    modelName: "text-embedding-3-small",
  });

  const vector = await model.embedQuery("Hello world");
  const [firstComponent] = vector;

  expect(typeof firstComponent).toBe("number");
  // The vector length must match the requested custom dimension.
  expect(vector.length).toBe(256);
});
Loading