6 changes: 6 additions & 0 deletions .changeset/cyan-seahorses-admire.md
@@ -0,0 +1,6 @@
---
'@ai-sdk/provider': patch
'ai': patch
---

This release introduces `wrapEmbeddingModel`, a new helper that brings customization capabilities to embedding models, similar to what `wrapLanguageModel` provides for language models.
31 changes: 31 additions & 0 deletions content/docs/03-ai-sdk-core/30-embeddings.mdx
@@ -193,6 +193,37 @@ const { embedding, response } = await embed({
console.log(response); // Raw provider response
```

## Embedding Middleware

You can enhance embedding models with `wrapEmbeddingModel` and
`EmbeddingModelV3Middleware`, for example to set default values.

Here is an example that uses the built-in `defaultEmbeddingSettingsMiddleware`:

```ts
import { google } from '@ai-sdk/google';
import {
  defaultEmbeddingSettingsMiddleware,
  embed,
  wrapEmbeddingModel,
} from 'ai';

const embeddingModelWithDefaults = wrapEmbeddingModel({
  model: google.textEmbedding('gemini-embedding-001'),
  middleware: defaultEmbeddingSettingsMiddleware({
    settings: {
      providerOptions: {
        google: {
          outputDimensionality: 256,
          taskType: 'CLASSIFICATION',
        },
      },
    },
  }),
});
```
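
The wrapped model can then be used like any other embedding model, for example with `embed` (a minimal continuation of the example above):

```ts
const { embedding } = await embed({
  model: embeddingModelWithDefaults,
  value: 'rainy afternoon in the city',
});
```

Because the middleware merges its defaults into the call options, settings passed on the call itself take precedence over the configured defaults.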

## Embedding Providers & Models

Several providers offer embedding models:
@@ -4,6 +4,7 @@ import { generateText, wrapLanguageModel } from 'ai';
import 'dotenv/config';

const logProviderMetadataMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',
  transformParams: async ({ params }) => {
    console.log(
      'providerOptions: ' + JSON.stringify(params.providerOptions, null, 2),
@@ -0,0 +1,36 @@
import { google } from '@ai-sdk/google';
import {
  customProvider,
  defaultEmbeddingSettingsMiddleware,
  embed,
  wrapEmbeddingModel,
} from 'ai';
import { print } from '../../lib/print';
import { run } from '../../lib/run';

const custom = customProvider({
  textEmbeddingModels: {
    'powerful-embedding-model': wrapEmbeddingModel({
      model: google.textEmbedding('gemini-embedding-001'),
      middleware: defaultEmbeddingSettingsMiddleware({
        settings: {
          providerOptions: {
            google: {
              outputDimensionality: 256,
              taskType: 'CLASSIFICATION',
            },
          },
        },
      }),
    }),
  },
});

run(async () => {
  const result = await embed({
    model: custom.textEmbeddingModel('powerful-embedding-model'),
    value: 'rainy afternoon in the city',
  });

  print('Embedding length:', result.embedding.length);
});
1 change: 1 addition & 0 deletions examples/ai-core/src/middleware/your-cache-middleware.ts
@@ -3,6 +3,7 @@ import { LanguageModelV3Middleware } from '@ai-sdk/provider';
const cache = new Map<string, any>();

export const yourCacheMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',
  wrapGenerate: async ({ doGenerate, params }) => {
    const cacheKey = JSON.stringify(params);
@@ -4,6 +4,7 @@ import {
} from '@ai-sdk/provider';

export const yourGuardrailMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',
  wrapGenerate: async ({ doGenerate }) => {
    const { content, ...rest } = await doGenerate();
1 change: 1 addition & 0 deletions examples/ai-core/src/middleware/your-log-middleware.ts
@@ -4,6 +4,7 @@ import {
} from '@ai-sdk/provider';

export const yourLogMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',
  wrapGenerate: async ({ doGenerate, params }) => {
    console.log('doGenerate called');
    console.log(`params: ${JSON.stringify(params, null, 2)}`);
1 change: 1 addition & 0 deletions examples/ai-core/src/middleware/your-rag-middleware.ts
@@ -3,6 +3,7 @@ import { addToLastUserMessage } from './add-to-last-user-message';
import { getLastUserMessageText } from './get-last-user-message-text';

export const yourRagMiddleware: LanguageModelV3Middleware = {
  specificationVersion: 'v3',
  transformParams: async ({ params }) => {
    const lastUserMessageText = getLastUserMessageText({
      prompt: params.prompt,
@@ -0,0 +1,126 @@
import { EmbeddingModelCallOptions } from '@ai-sdk/provider';
import { defaultEmbeddingSettingsMiddleware } from './default-embedding-settings-middleware';
import { MockEmbeddingModelV3 } from '../test/mock-embedding-model-v3';
import { describe, it, expect } from 'vitest';

const params: EmbeddingModelCallOptions<string> = {
  values: ['hello world'],
};

const mockModel = new MockEmbeddingModelV3();

describe('headers', () => {
  it('should merge headers', async () => {
    const middleware = defaultEmbeddingSettingsMiddleware({
      settings: {
        headers: { 'X-Custom-Header': 'test', 'X-Another-Header': 'test2' },
      },
    });
    const result = await middleware.transformParams!({
      params: {
        ...params,
        headers: { 'X-Custom-Header': 'test2' },
      },
      model: mockModel,
    });
    expect(result.headers).toEqual({
      'X-Custom-Header': 'test2',
      'X-Another-Header': 'test2',
    });
  });

  it('should handle empty default headers', async () => {
    const middleware = defaultEmbeddingSettingsMiddleware({
      settings: { headers: {} },
    });
    const result = await middleware.transformParams!({
      params: { ...params, headers: { 'X-Param-Header': 'param' } },
      model: mockModel,
    });
    expect(result.headers).toEqual({ 'X-Param-Header': 'param' });
  });

  it('should handle empty param headers', async () => {
    const middleware = defaultEmbeddingSettingsMiddleware({
      settings: { headers: { 'X-Default-Header': 'default' } },
    });
    const result = await middleware.transformParams!({
      params: { ...params, headers: {} },
      model: mockModel,
    });
    expect(result.headers).toEqual({ 'X-Default-Header': 'default' });
  });

  it('should handle both headers being undefined', async () => {
    const middleware = defaultEmbeddingSettingsMiddleware({
      settings: {},
    });
    const result = await middleware.transformParams!({
      params: { ...params },
      model: mockModel,
    });
    expect(result.headers).toBeUndefined();
  });
});

describe('providerOptions', () => {
  it('should handle empty default providerOptions', async () => {
    const middleware = defaultEmbeddingSettingsMiddleware({
      settings: {
        providerOptions: {},
      },
    });
    const result = await middleware.transformParams!({
      params: {
        ...params,
        providerOptions: {
          google: {
            outputDimensionality: 512,
            taskType: 'SEMANTIC_SIMILARITY',
          },
        },
      },
      model: mockModel,
    });
    expect(result.providerOptions).toEqual({
      google: {
        outputDimensionality: 512,
        taskType: 'SEMANTIC_SIMILARITY',
      },
    });
  });

  it('should handle empty param providerOptions', async () => {
    const middleware = defaultEmbeddingSettingsMiddleware({
      settings: {
        providerOptions: {
          google: {
            outputDimensionality: 512,
            taskType: 'SEMANTIC_SIMILARITY',
          },
        },
      },
    });
    const result = await middleware.transformParams!({
      params: { ...params, providerOptions: {} },
      model: mockModel,
    });
    expect(result.providerOptions).toEqual({
      google: {
        outputDimensionality: 512,
        taskType: 'SEMANTIC_SIMILARITY',
      },
    });
  });

  it('should handle both providerOptions being undefined', async () => {
    const middleware = defaultEmbeddingSettingsMiddleware({
      settings: {},
    });
    const result = await middleware.transformParams!({
      params: { ...params },
      model: mockModel,
    });
    expect(result.providerOptions).toBeUndefined();
  });
});
@@ -0,0 +1,25 @@
import { EmbeddingModelCallOptions } from '@ai-sdk/provider';
import { EmbeddingModelMiddleware } from '../types';
import { mergeObjects } from '../util/merge-objects';

/**
 * Applies default settings for an embedding model.
 */
export function defaultEmbeddingSettingsMiddleware({
  settings,
}: {
  settings: Partial<{
    headers?: EmbeddingModelCallOptions<string>['headers'];
    providerOptions?: EmbeddingModelCallOptions<string>['providerOptions'];
  }>;
}): EmbeddingModelMiddleware {
  return {
    specificationVersion: 'v3',
    transformParams: async ({ params }) => {
      return mergeObjects(
        settings,
        params,
      ) as EmbeddingModelCallOptions<string>;
    },
  };
}
2 changes: 1 addition & 1 deletion packages/ai/src/middleware/default-settings-middleware.ts
@@ -25,7 +25,7 @@ export function defaultSettingsMiddleware({
  }>;
}): LanguageModelMiddleware {
  return {
-    middlewareVersion: 'v3',
+    specificationVersion: 'v3',
    transformParams: async ({ params }) => {
      return mergeObjects(settings, params) as LanguageModelV3CallOptions;
    },
2 changes: 1 addition & 1 deletion packages/ai/src/middleware/extract-reasoning-middleware.ts
@@ -26,7 +26,7 @@ export function extractReasoningMiddleware({
  const closingTag = `<\/${tagName}>`;

  return {
-    middlewareVersion: 'v3',
+    specificationVersion: 'v3',
    wrapGenerate: async ({ doGenerate }) => {
      const { content, ...rest } = await doGenerate();
2 changes: 2 additions & 0 deletions packages/ai/src/middleware/index.ts
@@ -1,5 +1,7 @@
export { defaultEmbeddingSettingsMiddleware } from './default-embedding-settings-middleware';
export { defaultSettingsMiddleware } from './default-settings-middleware';
export { extractReasoningMiddleware } from './extract-reasoning-middleware';
export { simulateStreamingMiddleware } from './simulate-streaming-middleware';
export { wrapLanguageModel } from './wrap-language-model';
export { wrapEmbeddingModel } from './wrap-embedding-model';
export { wrapProvider } from './wrap-provider';
@@ -6,7 +6,7 @@ import { LanguageModelMiddleware } from '../types';
 */
export function simulateStreamingMiddleware(): LanguageModelMiddleware {
  return {
-    middlewareVersion: 'v3',
+    specificationVersion: 'v3',
    wrapStream: async ({ doGenerate }) => {
      const result = await doGenerate();