Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/cyan-seahorses-admire.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@ai-sdk/provider': patch
'ai': patch
---

This release introduces `wrapEmbeddingModel`, a new helper that brings embedding model customization capabilities similar to `wrapLanguageModel`.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { google } from '@ai-sdk/google';
import {
customProvider,
defaultEmbeddingSettingsMiddleware,
embedMany,
embed,
wrapEmbeddingModel,
} from 'ai';
import 'dotenv/config';

const centralSpace = customProvider({
textEmbeddingModels: {
'powerful-embedding-model': wrapEmbeddingModel({
model: google.textEmbedding('gemini-embedding-001'),
middleware: defaultEmbeddingSettingsMiddleware({
settings: {
providerOptions: {
google: {
outputDimensionality: 256,
taskType: 'CLASSIFICATION',
},
},
},
}),
}),
},
});
async function main() {
const embedManyResponse = await embedMany({
model: centralSpace.textEmbeddingModel('powerful-embedding-model'),
values: [
'sunny day at the beach',
'rainy afternoon in the city',
'snowy night in the mountains',
],
});
console.log(embedManyResponse.embeddings);

const response = await embed({
model: centralSpace.textEmbeddingModel('powerful-embedding-model'),
value: 'rainy afternoon in the city',
});
console.log(response.embedding);
}

main().catch(console.error);
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import { EmbeddingModelCallOptions } from '@ai-sdk/provider';
import { defaultEmbeddingSettingsMiddleware } from './default-embedding-settings-middleware';
import { MockEmbeddingModelV3 } from '../test/mock-embedding-model-v3';
import { describe, it, expect } from 'vitest';

const params: EmbeddingModelCallOptions<string> = {
values: ['hello world'],
};

const mockModel = new MockEmbeddingModelV3();

describe('headers', () => {
it('should merge headers', async () => {
const middleware = defaultEmbeddingSettingsMiddleware({
settings: {
headers: { 'X-Custom-Header': 'test', 'X-Another-Header': 'test2' },
},
});
const result = await middleware.transformParams!({
params: {
...params,
headers: { 'X-Custom-Header': 'test2' },
},
model: mockModel,
});
expect(result.headers).toEqual({
'X-Custom-Header': 'test2',
'X-Another-Header': 'test2',
});
});

it('should handle empty default headers', async () => {
const middleware = defaultEmbeddingSettingsMiddleware({
settings: { headers: {} },
});
const result = await middleware.transformParams!({
params: { ...params, headers: { 'X-Param-Header': 'param' } },
model: mockModel,
});
expect(result.headers).toEqual({ 'X-Param-Header': 'param' });
});

it('should handle empty param headers', async () => {
const middleware = defaultEmbeddingSettingsMiddleware({
settings: { headers: { 'X-Default-Header': 'default' } },
});
const result = await middleware.transformParams!({
params: { ...params, headers: {} },
model: mockModel,
});
expect(result.headers).toEqual({ 'X-Default-Header': 'default' });
});

it('should handle both headers being undefined', async () => {
const middleware = defaultEmbeddingSettingsMiddleware({
settings: {},
});
const result = await middleware.transformParams!({
params: { ...params },
model: mockModel,
});
expect(result.headers).toBeUndefined();
});
});

describe('providerOptions', () => {
it('should handle empty default providerOptions', async () => {
const middleware = defaultEmbeddingSettingsMiddleware({
settings: {
providerOptions: {},
},
});
const result = await middleware.transformParams!({
params: {
...params,
providerOptions: {
google: {
outputDimensionality: 512,
taskType: 'SEMANTIC_SIMILARITY',
},
},
},
model: mockModel,
});
expect(result.providerOptions).toEqual({
google: {
outputDimensionality: 512,
taskType: 'SEMANTIC_SIMILARITY',
},
});
});

it('should handle empty param providerOptions', async () => {
const middleware = defaultEmbeddingSettingsMiddleware({
settings: {
providerOptions: {
google: {
outputDimensionality: 512,
taskType: 'SEMANTIC_SIMILARITY',
},
},
},
});
const result = await middleware.transformParams!({
params: { ...params, providerOptions: {} },
model: mockModel,
});
expect(result.providerOptions).toEqual({
google: {
outputDimensionality: 512,
taskType: 'SEMANTIC_SIMILARITY',
},
});
});

it('should handle both providerOptions being undefined', async () => {
const middleware = defaultEmbeddingSettingsMiddleware({
settings: {},
});
const result = await middleware.transformParams!({
params: { ...params },
model: mockModel,
});
expect(result.providerOptions).toBeUndefined();
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { EmbeddingModelCallOptions } from '@ai-sdk/provider';
import { EmbeddingModelMiddleware } from '../types';
import { mergeObjects } from '../util/merge-objects';

/**
* Applies default settings for a embedding model.
*/
export function defaultEmbeddingSettingsMiddleware({
settings,
}: {
settings: Partial<{
headers?: EmbeddingModelCallOptions<string>['headers'];
providerOptions?: EmbeddingModelCallOptions<string>['providerOptions'];
}>;
}): EmbeddingModelMiddleware {
return {
middlewareVersion: 'v3',
transformParams: async ({ params }) => {
return mergeObjects(
settings,
params,
) as EmbeddingModelCallOptions<string>;
},
};
}
2 changes: 2 additions & 0 deletions packages/ai/src/middleware/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
export { defaultEmbeddingSettingsMiddleware } from './default-embedding-settings-middleware';
export { defaultSettingsMiddleware } from './default-settings-middleware';
export { extractReasoningMiddleware } from './extract-reasoning-middleware';
export { simulateStreamingMiddleware } from './simulate-streaming-middleware';
export { wrapLanguageModel } from './wrap-language-model';
export { wrapEmbeddingModel } from './wrap-embedding-model';
export { wrapProvider } from './wrap-provider';
Loading