diff --git a/packages/workers-ai-provider/package.json b/packages/workers-ai-provider/package.json
index 0eb26a7f..67828df4 100644
--- a/packages/workers-ai-provider/package.json
+++ b/packages/workers-ai-provider/package.json
@@ -2,7 +2,7 @@
 	"name": "workers-ai-provider",
 	"description": "Workers AI Provider for the vercel AI SDK",
 	"type": "module",
-	"version": "0.5.2",
+	"version": "0.5.3-v5-beta",
 	"main": "dist/index.js",
 	"types": "dist/index.d.ts",
 	"repository": {
@@ -21,10 +21,25 @@
 		"test:ci": "vitest --watch=false",
 		"test": "vitest"
 	},
-	"files": ["dist", "src", "README.md", "package.json"],
-	"keywords": ["workers", "cloudflare", "ai", "vercel", "sdk", "provider", "chat", "serverless"],
+	"files": [
+		"dist",
+		"src",
+		"README.md",
+		"package.json"
+	],
+	"keywords": [
+		"workers",
+		"cloudflare",
+		"ai",
+		"vercel",
+		"sdk",
+		"provider",
+		"chat",
+		"serverless"
+	],
 	"dependencies": {
-		"@ai-sdk/provider": "^1.1.3"
+		"@ai-sdk/provider": "2.0.0-beta.1",
+		"@ai-sdk/provider-utils": "2.2.8"
 	},
 	"devDependencies": {
 		"@cloudflare/workers-types": "^4.20250525.0"
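For context while reviewing the migration below: a minimal consumer-side sketch of what this v5-beta build is meant to support. It assumes the AI SDK v5 beta surface (`streamText`, `toTextStreamResponse`) and a `wrangler`-configured `AI` binding; the model id is illustrative and not part of this diff.

```ts
import { streamText } from "ai";
import { createWorkersAI } from "workers-ai-provider";

export default {
	async fetch(_request: Request, env: { AI: Ai }): Promise<Response> {
		const workersai = createWorkersAI({ binding: env.AI });
		const result = streamText({
			model: workersai("@cf/meta/llama-3.1-8b-instruct"),
			prompt: "Write a haiku about edge computing.",
		});
		// Streams text back to the client as deltas arrive.
		return result.toTextStreamResponse();
	},
};
```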
diff --git a/packages/workers-ai-provider/src/autorag-chat-language-model.ts b/packages/workers-ai-provider/src/autorag-chat-language-model.ts
index 40abacd4..4f6e093d 100644
--- a/packages/workers-ai-provider/src/autorag-chat-language-model.ts
+++ b/packages/workers-ai-provider/src/autorag-chat-language-model.ts
@@ -1,14 +1,12 @@
 import {
-	type LanguageModelV1,
-	type LanguageModelV1CallWarning,
-	UnsupportedFunctionalityError,
+	type LanguageModelV2,
+	type LanguageModelV2CallWarning,
+	type LanguageModelV2Content,
 } from "@ai-sdk/provider";
 import type { AutoRAGChatSettings } from "./autorag-chat-settings";
 import { convertToWorkersAIChatMessages } from "./convert-to-workersai-chat-messages";
 import { mapWorkersAIUsage } from "./map-workersai-usage";
 import { getMappedStream } from "./streaming";
-import { prepareToolsAndToolChoice, processToolCalls } from "./utils";
 import type { TextGenerationModels } from "./workersai-models";
 
 type AutoRAGChatConfig = {
@@ -17,13 +15,15 @@
 	gateway?: GatewayOptions;
 };
 
-export class AutoRAGChatLanguageModel implements LanguageModelV1 {
-	readonly specificationVersion = "v1";
+export class AutoRAGChatLanguageModel implements LanguageModelV2 {
+	readonly specificationVersion = "v2";
 	readonly defaultObjectGenerationMode = "json";
 
 	readonly modelId: TextGenerationModels;
 	readonly settings: AutoRAGChatSettings;
 
+	readonly supportedUrls = {};
+
 	private readonly config: AutoRAGChatConfig;
 
 	constructor(
@@ -41,14 +41,13 @@
 	private getArgs({
-		mode,
 		prompt,
 		frequencyPenalty,
 		presencePenalty,
-	}: Parameters<LanguageModelV1["doGenerate"]>[0]) {
-		const type = mode.type;
-
-		const warnings: LanguageModelV1CallWarning[] = [];
+		tools,
+		toolChoice,
+	}: Parameters<LanguageModelV2["doGenerate"]>[0]) {
+		const warnings: LanguageModelV2CallWarning[] = [];
 
 		if (frequencyPenalty != null) {
 			warnings.push({
@@ -72,58 +71,21 @@
 			messages: convertToWorkersAIChatMessages(prompt),
 		};
 
-		switch (type) {
-			case "regular": {
-				return {
-					args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },
-					warnings,
-				};
-			}
-
-			case "object-json": {
-				return {
-					args: {
-						...baseArgs,
-						response_format: {
-							type: "json_schema",
-							json_schema: mode.schema,
-						},
-						tools: undefined,
-					},
-					warnings,
-				};
-			}
-
-			case "object-tool": {
-				return {
-					args: {
-						...baseArgs,
-						tool_choice: "any",
-						tools: [{ type: "function", function: mode.tool }],
-					},
-					warnings,
-				};
-			}
-
-			// @ts-expect-error - this is unreachable code
-			// TODO: fixme
-			case "object-grammar": {
-				throw new UnsupportedFunctionalityError({
-					functionality: "object-grammar mode",
-				});
-			}
-
-			default: {
-				const exhaustiveCheck = type satisfies never;
-				throw new Error(`Unsupported type: ${exhaustiveCheck}`);
-			}
-		}
+		return {
+			args: {
+				...baseArgs,
+				tool_choice: toolChoice,
+				tools,
+			},
+			warnings,
+		};
 	}
 
 	async doGenerate(
-		options: Parameters<LanguageModelV1["doGenerate"]>[0],
-	): Promise<Awaited<ReturnType<LanguageModelV1["doGenerate"]>>> {
-		const { args, warnings } = this.getArgs(options);
+		options: Parameters<LanguageModelV2["doGenerate"]>[0],
+	): Promise<Awaited<ReturnType<LanguageModelV2["doGenerate"]>>> {
+		const { warnings } = this.getArgs(options);
 
 		const { messages } = convertToWorkersAIChatMessages(options.prompt);
 
@@ -131,28 +93,29 @@
 			query: messages.map(({ content, role }) => `${role}: ${content}`).join("\n\n"),
 		});
 
+		// Map the AutoRAG output onto the V2 content-part union: the answer text
+		// plus one `source` part per retrieved document.
+		const content: LanguageModelV2Content[] = [];
+		if (output.response) {
+			content.push({ type: "text", text: output.response });
+		}
+		for (const { file_id, filename, score } of output.data) {
+			content.push({
+				type: "source",
+				sourceType: "url",
+				id: file_id,
+				url: filename,
+				providerMetadata: { attributes: { score } },
+			});
+		}
+
 		return {
-			text: output.response,
-			toolCalls: processToolCalls(output),
+			content,
 			finishReason: "stop", // TODO: mapWorkersAIFinishReason(response.finish_reason)
-			rawCall: { rawPrompt: args.messages, rawSettings: args },
 			usage: mapWorkersAIUsage(output),
 			warnings,
-			sources: output.data.map(({ file_id, filename, score }) => ({
-				id: file_id,
-				sourceType: "url",
-				url: filename,
-				providerMetadata: {
-					attributes: { score },
-				},
-			})),
 		};
 	}
 
 	async doStream(
-		options: Parameters<LanguageModelV1["doStream"]>[0],
-	): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>> {
-		const { args, warnings } = this.getArgs(options);
+		options: Parameters<LanguageModelV2["doStream"]>[0],
+	): Promise<Awaited<ReturnType<LanguageModelV2["doStream"]>>> {
+		const { warnings } = this.getArgs(options);
 
 		const { messages } = convertToWorkersAIChatMessages(options.prompt);
 
@@ -165,8 +128,8 @@
 
 		return {
-			stream: getMappedStream(response),
-			rawCall: { rawPrompt: args.messages, rawSettings: args },
-			warnings,
+			// V2 delivers warnings via the `stream-start` part, emitted inside getMappedStream.
+			stream: getMappedStream(response, warnings),
 		};
 	}
 }
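The V1→V2 change above replaces the separate `text`/`toolCalls`/`sources` result fields with a single `content` array. A sketch of the shape the AutoRAG model now returns, with illustrative values and assuming the `LanguageModelV2Content` union from `@ai-sdk/provider@2.0.0-beta.1`:

```ts
import type { LanguageModelV2Content } from "@ai-sdk/provider";

const content: LanguageModelV2Content[] = [
	{ type: "text", text: "Paris is the capital of France." },
	{
		type: "source",
		sourceType: "url",
		id: "file-123", // AutoRAG file_id
		url: "docs/france.md", // AutoRAG filename
		providerMetadata: { attributes: { score: 0.87 } },
	},
];
```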
diff --git a/packages/workers-ai-provider/src/convert-to-workersai-chat-messages.ts b/packages/workers-ai-provider/src/convert-to-workersai-chat-messages.ts
index 82a412ea..05e6fc14 100644
--- a/packages/workers-ai-provider/src/convert-to-workersai-chat-messages.ts
+++ b/packages/workers-ai-provider/src/convert-to-workersai-chat-messages.ts
@@ -1,25 +1,35 @@
-import type { LanguageModelV1Prompt, LanguageModelV1ProviderMetadata } from "@ai-sdk/provider";
+import type { LanguageModelV2Prompt, LanguageModelV2ProviderMetadata } from "@ai-sdk/provider";
 import type { WorkersAIChatPrompt } from "./workersai-chat-prompt";
 
-export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): {
+export function convertToWorkersAIChatMessages(prompt: LanguageModelV2Prompt): {
 	messages: WorkersAIChatPrompt;
 	images: {
 		mimeType: string | undefined;
 		image: Uint8Array;
-		providerMetadata: LanguageModelV1ProviderMetadata | undefined;
+		providerMetadata: LanguageModelV2ProviderMetadata | undefined;
 	}[];
 } {
 	const messages: WorkersAIChatPrompt = [];
 	const images: {
 		mimeType: string | undefined;
 		image: Uint8Array;
-		providerMetadata: LanguageModelV1ProviderMetadata | undefined;
+		providerMetadata: LanguageModelV2ProviderMetadata | undefined;
 	}[] = [];
 
 	for (const { role, content } of prompt) {
 		switch (role) {
 			case "system": {
-				messages.push({ role: "system", content });
+				messages.push({
+					role: "system",
+					// V2 system content is a plain string; guard the part-array case defensively.
+					content:
+						typeof content === "string"
+							? content
+							: (content as Array<{ type: string; text?: string }>)
+									.map((part) => (part.type === "text" ? (part.text ?? "") : ""))
+									.join("\n"),
+				});
 				break;
 			}
 
@@ -95,10 +105,10 @@
 				tool_calls:
 					toolCalls.length > 0
-						? toolCalls.map(({ function: { name, arguments: args } }) => ({
-							id: "null",
-							type: "function",
-							function: { name, arguments: args },
-						}))
+						? toolCalls.map(({ function: { name, arguments: args } }) => ({
+								id: "null",
+								type: "function",
+								function: { name, arguments: args },
+							}))
 						: undefined,
 			});
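To make the converter's contract after this change concrete: a V2 prompt goes in (system content as a string, user/assistant content as part arrays) and flat `{ role, content }` messages come out. A sketch with illustrative values; the deep import path is for illustration only:

```ts
import type { LanguageModelV2Prompt } from "@ai-sdk/provider";
import { convertToWorkersAIChatMessages } from "workers-ai-provider/src/convert-to-workersai-chat-messages";

const prompt: LanguageModelV2Prompt = [
	{ role: "system", content: "You are terse." },
	{ role: "user", content: [{ type: "text", text: "Hi!" }] },
];

const { messages } = convertToWorkersAIChatMessages(prompt);
// → [{ role: "system", content: "You are terse." },
//    { role: "user", content: "Hi!" }]
```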
diff --git a/packages/workers-ai-provider/src/index.ts b/packages/workers-ai-provider/src/index.ts
index fbc50ac1..aed03aee 100644
--- a/packages/workers-ai-provider/src/index.ts
+++ b/packages/workers-ai-provider/src/index.ts
@@ -1,3 +1,4 @@
+import type { ProviderV2 } from "@ai-sdk/provider";
 import { AutoRAGChatLanguageModel } from "./autorag-chat-language-model";
 import type { AutoRAGChatSettings } from "./autorag-chat-settings";
 import { createRun } from "./utils";
@@ -46,11 +50,15 @@
 	gateway?: GatewayOptions;
 };
 
-export interface WorkersAI {
+export interface WorkersAI extends ProviderV2 {
 	(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
+
 	/**
 	 * Creates a model for text generation.
+	 * @deprecated Use `.languageModel()` instead.
 	 **/
 	chat(
 		modelId: TextGenerationModels,
 		settings?: WorkersAIChatSettings,
@@ -73,8 +81,19 @@
 
 	/**
 	 * Creates a model for image generation.
+	 * @deprecated Use `.imageModel()` instead.
 	 **/
 	image(modelId: ImageGenerationModels, settings?: WorkersAIImageSettings): WorkersAIImageModel;
+
+	/**
+	 * Creates a model for text generation.
+	 **/
+	languageModel(
+		modelId: TextGenerationModels,
+		settings?: WorkersAIChatSettings,
+	): WorkersAIChatLanguageModel;
+
+	/**
+	 * Creates a model for image generation.
+	 **/
+	imageModel(modelId: ImageGenerationModels, settings?: WorkersAIImageSettings): WorkersAIImageModel;
 }
@@ -131,7 +151,8 @@
 		return createChatModel(modelId, settings);
 	};
 
-	provider.chat = createChatModel;
+	provider.chat = createChatModel; // deprecated alias for `languageModel`
+	provider.languageModel = createChatModel;
 	provider.embedding = createEmbeddingModel;
 	provider.textEmbedding = createEmbeddingModel;
 	provider.textEmbeddingModel = createEmbeddingModel;
+	provider.imageModel = createImageModel;
diff --git a/packages/workers-ai-provider/src/map-workersai-finish-reason.ts b/packages/workers-ai-provider/src/map-workersai-finish-reason.ts
index 02b5575e..0a2f9123 100644
--- a/packages/workers-ai-provider/src/map-workersai-finish-reason.ts
+++ b/packages/workers-ai-provider/src/map-workersai-finish-reason.ts
@@ -1,8 +1,8 @@
-import type { LanguageModelV1FinishReason } from "@ai-sdk/provider";
+import type { LanguageModelV2FinishReason } from "@ai-sdk/provider";
 
 export function mapWorkersAIFinishReason(
 	finishReason: string | null | undefined,
-): LanguageModelV1FinishReason {
+): LanguageModelV2FinishReason {
 	switch (finishReason) {
 		case "stop":
 			return "stop";
diff --git a/packages/workers-ai-provider/src/map-workersai-usage.ts b/packages/workers-ai-provider/src/map-workersai-usage.ts
index be7f8c98..327d3805 100644
--- a/packages/workers-ai-provider/src/map-workersai-usage.ts
+++ b/packages/workers-ai-provider/src/map-workersai-usage.ts
@@ -1,7 +1,9 @@
-export function mapWorkersAIUsage(output: AiTextGenerationOutput | AiTextToImageOutput) {
+import type { LanguageModelV2Usage } from "@ai-sdk/provider";
+
+export function mapWorkersAIUsage(output: AiTextGenerationOutput | AiTextToImageOutput): LanguageModelV2Usage {
 	const usage = (
 		output as {
-			usage: { prompt_tokens: number; completion_tokens: number };
+			usage?: { prompt_tokens: number; completion_tokens: number };
 		}
 	).usage ?? {
 		prompt_tokens: 0,
@@ -9,7 +11,9 @@
 	};
 
 	return {
-		promptTokens: usage.prompt_tokens,
-		completionTokens: usage.completion_tokens,
+		inputTokens: usage.prompt_tokens,
+		outputTokens: usage.completion_tokens,
+		totalTokens: usage.prompt_tokens + usage.completion_tokens,
 	};
 }
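Under the new `ProviderV2` interface both the deprecated and the new entry points stay callable; a quick sketch (model id illustrative):

```ts
import { createWorkersAI } from "workers-ai-provider";

export function pickModel(env: { AI: Ai }) {
	const workersai = createWorkersAI({ binding: env.AI });

	// All three resolve to the same WorkersAIChatLanguageModel:
	const viaCall = workersai("@cf/meta/llama-3.1-8b-instruct");
	const viaDeprecatedChat = workersai.chat("@cf/meta/llama-3.1-8b-instruct");
	const viaLanguageModel = workersai.languageModel("@cf/meta/llama-3.1-8b-instruct");

	return { viaCall, viaDeprecatedChat, viaLanguageModel };
}
```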
diff --git a/packages/workers-ai-provider/src/streaming.ts b/packages/workers-ai-provider/src/streaming.ts
index a382a032..781db0af 100644
--- a/packages/workers-ai-provider/src/streaming.ts
+++ b/packages/workers-ai-provider/src/streaming.ts
@@ -1,24 +1,31 @@
 import { events } from "fetch-event-stream";
 
-import type { LanguageModelV1StreamPart } from "@ai-sdk/provider";
+import type {
+	LanguageModelV2CallWarning,
+	LanguageModelV2StreamPart,
+	LanguageModelV2Usage,
+} from "@ai-sdk/provider";
 import { mapWorkersAIUsage } from "./map-workersai-usage";
 import { processPartialToolCalls } from "./utils";
 
-export function getMappedStream(response: Response) {
+export function getMappedStream(response: Response, warnings: LanguageModelV2CallWarning[] = []) {
 	const chunkEvent = events(response);
-	let usage = { promptTokens: 0, completionTokens: 0 };
+	let usage: LanguageModelV2Usage = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
 	const partialToolCalls: any[] = [];
+	let textStarted = false;
+	let textId: string | undefined;
 
-	return new ReadableStream({
+	return new ReadableStream<LanguageModelV2StreamPart>({
 		async start(controller) {
+			controller.enqueue({ type: "stream-start", warnings });
+
 			for await (const event of chunkEvent) {
 				if (!event.data) {
 					continue;
 				}
 				if (event.data === "[DONE]") {
 					break;
 				}
 				const chunk = JSON.parse(event.data);
 				if (chunk.usage) {
 					usage = mapWorkersAIUsage(chunk);
 				}
@@ -26,11 +33,32 @@
 					partialToolCalls.push(...chunk.tool_calls);
 					continue;
 				}
-				chunk.response?.length &&
+				if (chunk.response?.length) {
+					// A text block is bracketed by text-start/text-end, and every
+					// delta must carry the same id as its text-start part.
+					if (!textStarted) {
+						textId = crypto.randomUUID();
+						controller.enqueue({ type: "text-start", id: textId });
+						textStarted = true;
+					}
 					controller.enqueue({
 						type: "text-delta",
-						textDelta: chunk.response,
+						id: textId as string,
+						delta: chunk.response,
 					});
+				}
+			}
+
+			if (textStarted && textId) {
+				controller.enqueue({ type: "text-end", id: textId });
 			}
 
 			if (partialToolCalls.length > 0) {
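For reference, the part sequence this mapper now emits under the v2 streaming protocol — a standalone sketch with illustrative values (the `stream-start`/`text-start`/`text-end` framing is assumed from `@ai-sdk/provider@2.0.0-beta.1`):

```ts
import type { LanguageModelV2StreamPart } from "@ai-sdk/provider";

const parts: LanguageModelV2StreamPart[] = [
	{ type: "stream-start", warnings: [] },
	{ type: "text-start", id: "txt-1" },
	{ type: "text-delta", id: "txt-1", delta: "Hello, " },
	{ type: "text-delta", id: "txt-1", delta: "world!" },
	{ type: "text-end", id: "txt-1" },
	{
		type: "finish",
		finishReason: "stop",
		usage: { inputTokens: 12, outputTokens: 4, totalTokens: 16 },
	},
];
```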
diff --git a/packages/workers-ai-provider/src/utils.ts b/packages/workers-ai-provider/src/utils.ts
index 73759deb..fae6f06b 100644
--- a/packages/workers-ai-provider/src/utils.ts
+++ b/packages/workers-ai-provider/src/utils.ts
@@ -1,4 +1,12 @@
-import type { LanguageModelV1, LanguageModelV1FunctionToolCall } from "@ai-sdk/provider";
+// Custom type to replace LanguageModelV1FunctionToolCall for v5 compatibility.
+interface FunctionToolCall {
+	toolCallType: "function";
+	toolCallId: string;
+	toolName: string;
+	args: string;
+}
 
 /**
  * General AI run interface with overloads to handle distinct return types.
@@ -85,9 +93,8 @@
 			}
 		}
 
-		const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${
-			urlParams ? `?${urlParams}` : ""
-		}`;
+		const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${urlParams ? `?${urlParams}` : ""}`;
 
 		// Merge default and custom headers.
 		const headers = {
@@ -97,12 +104,12 @@
 
 		const body = JSON.stringify(inputs);
 
 		// Execute the POST request. The optional AbortSignal is applied here.
 		const response = await fetch(url, {
 			method: "POST",
 			headers,
 			body,
 		});
 
 		// (1) If the user explicitly requests the raw Response, return it as-is.
 		if (returnRawResponse) {
@@ -126,7 +139,8 @@
 }
 
 export function prepareToolsAndToolChoice(
-	mode: Parameters<LanguageModelV1["doGenerate"]>[0]["mode"] & {
-		type: "regular";
-	},
+	// LanguageModelV2 call options no longer carry a `mode`, so this helper
+	// (now unused by the models) types its input structurally instead.
+	mode: {
+		type: "regular";
+		tools?: Array<{ name: string; description?: string; parameters?: unknown }>;
+		toolChoice?: { type: "auto" | "none" | "required" | "tool"; toolName?: string };
+	},
 ) {
@@ -137,13 +151,11 @@
 		return { tools: undefined, tool_choice: undefined };
 	}
 
 	const mappedTools = tools.map((tool) => ({
 		type: "function",
 		function: {
 			name: tool.name,
-			// @ts-expect-error - description is not a property of tool
 			description: tool.description,
-			// @ts-expect-error - parameters is not a property of tool
 			parameters: tool.parameters,
 		},
 	}));
@@ -168,10 +181,10 @@
 		// so we filter the tools and force the tool choice through 'any'
 		case "tool":
 			return {
 				tools: mappedTools.filter((tool) => tool.function.name === toolChoice.toolName),
 				tool_choice: "any",
 			};
 		default: {
 			const exhaustiveCheck = type satisfies never;
 			throw new Error(`Unsupported tool choice type: ${exhaustiveCheck}`);
 		}
@@ -219,7 +234,7 @@
 	return Object.values(mergedCallsByIndex);
 }
 
-function processToolCall(toolCall: any): LanguageModelV1FunctionToolCall {
+function processToolCall(toolCall: any): FunctionToolCall {
 	if (toolCall.function && toolCall.id) {
 		return {
 			toolCallType: "function",
@@ -242,7 +257,7 @@
 	};
 }
 
-export function processToolCalls(output: any): LanguageModelV1FunctionToolCall[] {
+export function processToolCalls(output: any): FunctionToolCall[] {
 	// Check for OpenAI format tool calls first
 	if (output.tool_calls && Array.isArray(output.tool_calls)) {
 		return output.tool_calls.map((toolCall: any) => {
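A sketch of the normalization `processToolCalls` performs on OpenAI-style tool calls — input values are hypothetical and the deep import path is for illustration only:

```ts
import { processToolCalls } from "workers-ai-provider/src/utils";

const openAiStyle = {
	tool_calls: [
		{ id: "call_1", function: { name: "getWeather", arguments: '{"city":"Lisbon"}' } },
	],
};

console.log(processToolCalls(openAiStyle));
// → [{ toolCallType: "function", toolCallId: "call_1",
//      toolName: "getWeather", args: '{"city":"Lisbon"}' }]
```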
diff --git a/packages/workers-ai-provider/src/workers-ai-embedding-model.ts b/packages/workers-ai-provider/src/workers-ai-embedding-model.ts
index f3cc4f56..24366f85 100644
--- a/packages/workers-ai-provider/src/workers-ai-embedding-model.ts
+++ b/packages/workers-ai-provider/src/workers-ai-embedding-model.ts
@@ -1,4 +1,4 @@
-import { TooManyEmbeddingValuesForCallError, type EmbeddingModelV1 } from "@ai-sdk/provider";
+import { TooManyEmbeddingValuesForCallError, type EmbeddingModelV2 } from "@ai-sdk/provider";
 import type { StringLike } from "./utils";
 import type { EmbeddingModels } from "./workersai-models";
 
@@ -19,12 +19,12 @@
 	[key: string]: StringLike;
 };
 
-export class WorkersAIEmbeddingModel implements EmbeddingModelV1<string> {
+export class WorkersAIEmbeddingModel implements EmbeddingModelV2<string> {
 	/**
-	 * Semantic version of the {@link EmbeddingModelV1} specification implemented
+	 * Semantic version of the {@link EmbeddingModelV2} specification implemented
 	 * by this class. It never changes.
 	 */
-	readonly specificationVersion = "v1";
+	readonly specificationVersion = "v2";
 	readonly modelId: EmbeddingModels;
 	private readonly config: WorkersAIEmbeddingConfig;
 	private readonly settings: WorkersAIEmbeddingSettings;
@@ -58,8 +58,8 @@
 
 	async doEmbed({
 		values,
-	}: Parameters<EmbeddingModelV1<string>["doEmbed"]>[0]): Promise<
-		Awaited<ReturnType<EmbeddingModelV1<string>["doEmbed"]>>
+	}: Parameters<EmbeddingModelV2<string>["doEmbed"]>[0]): Promise<
+		Awaited<ReturnType<EmbeddingModelV2<string>["doEmbed"]>>
 	> {
 		if (values.length > this.maxEmbeddingsPerCall) {
 			throw new TooManyEmbeddingValuesForCallError({
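A sketch of calling the migrated embedding model through the provider (model id illustrative; the `doEmbed` option shape is assumed from the `EmbeddingModelV2` types):

```ts
import { createWorkersAI } from "workers-ai-provider";

export async function embedDocs(env: { AI: Ai }) {
	const workersai = createWorkersAI({ binding: env.AI });
	const model = workersai.textEmbeddingModel("@cf/baai/bge-base-en-v1.5");

	const { embeddings } = await model.doEmbed({
		values: ["first document", "second document"],
	});
	return embeddings; // one numeric vector per input value
}
```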
diff --git a/packages/workers-ai-provider/src/workersai-chat-language-model.ts b/packages/workers-ai-provider/src/workersai-chat-language-model.ts
index d077e254..010a8bfd 100644
--- a/packages/workers-ai-provider/src/workersai-chat-language-model.ts
+++ b/packages/workers-ai-provider/src/workersai-chat-language-model.ts
@@ -1,16 +1,17 @@
 import {
-	type LanguageModelV1,
-	type LanguageModelV1CallWarning,
-	type LanguageModelV1StreamPart,
-	UnsupportedFunctionalityError,
+	type LanguageModelV2,
+	type LanguageModelV2CallWarning,
+	type LanguageModelV2Content,
+	type LanguageModelV2StreamPart,
 } from "@ai-sdk/provider";
+import { generateId } from "@ai-sdk/provider-utils";
 import { convertToWorkersAIChatMessages } from "./convert-to-workersai-chat-messages";
 import type { WorkersAIChatSettings } from "./workersai-chat-settings";
 import type { TextGenerationModels } from "./workersai-models";
 import { mapWorkersAIUsage } from "./map-workersai-usage";
-import { getMappedStream } from "./streaming";
-import { lastMessageWasUser, prepareToolsAndToolChoice, processToolCalls } from "./utils";
+import { lastMessageWasUser } from "./utils";
@@ -18,15 +19,20 @@
 type WorkersAIChatConfig = {
 	provider: string;
 	binding: Ai;
 	gateway?: GatewayOptions;
 };
 
-export class WorkersAIChatLanguageModel implements LanguageModelV1 {
-	readonly specificationVersion = "v1";
+export class WorkersAIChatLanguageModel implements LanguageModelV2 {
+	readonly specificationVersion = "v2";
 	readonly defaultObjectGenerationMode = "json";
 
 	readonly modelId: TextGenerationModels;
 	readonly settings: WorkersAIChatSettings;
 
+	readonly supportedUrls = {
+		"image/*": [/^https?:\/\/.*$/],
+	};
+
 	private readonly config: WorkersAIChatConfig;
 
 	constructor(
@@ -41,18 +47,18 @@
 	private getArgs({
-		mode,
-		maxTokens,
+		maxOutputTokens,
 		temperature,
 		topP,
 		frequencyPenalty,
 		presencePenalty,
 		seed,
-	}: Parameters<LanguageModelV1["doGenerate"]>[0]) {
-		const type = mode.type;
-
-		const warnings: LanguageModelV1CallWarning[] = [];
+		tools,
+		toolChoice,
+		responseFormat,
+	}: Parameters<LanguageModelV2["doGenerate"]>[0]) {
+		const warnings: LanguageModelV2CallWarning[] = [];
 
 		if (frequencyPenalty != null) {
 			warnings.push({
@@ -68,72 +74,43 @@
 		}
 
+		if (responseFormat != null && responseFormat.type !== "text") {
+			warnings.push({
+				type: "unsupported-setting",
+				setting: "responseFormat",
+				details: "JSON response format is not supported.",
+			});
+		}
 
-		const baseArgs = {
-			// model id:
-			model: this.modelId,
-
-			// model specific settings:
-			safe_prompt: this.settings.safePrompt,
-
-			// standardized settings:
-			max_tokens: maxTokens,
-			temperature,
-			top_p: topP,
-			random_seed: seed,
+		return {
+			args: {
+				// model id:
+				model: this.modelId,
+
+				// model specific settings:
+				safe_prompt: this.settings.safePrompt,
+
+				// standardized settings:
+				max_tokens: maxOutputTokens,
+				temperature,
+				top_p: topP,
+				random_seed: seed,
+
+				// response_format is intentionally not forwarded to Workers AI;
+				// a warning is emitted above instead.
+
+				// tools:
+				tools,
+				tool_choice: toolChoice,
+			},
+			warnings,
 		};
-
-		switch (type) {
-			case "regular": {
-				return {
-					args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },
-					warnings,
-				};
-			}
-
-			case "object-json": {
-				return {
-					args: {
-						...baseArgs,
-						response_format: {
-							type: "json_schema",
-							json_schema: mode.schema,
-						},
-						tools: undefined,
-					},
-					warnings,
-				};
-			}
-
-			case "object-tool": {
-				return {
-					args: {
-						...baseArgs,
-						tool_choice: "any",
-						tools: [{ type: "function", function: mode.tool }],
-					},
-					warnings,
-				};
-			}
-
-			// @ts-expect-error - this is unreachable code
-			// TODO: fixme
-			case "object-grammar": {
-				throw new UnsupportedFunctionalityError({
-					functionality: "object-grammar mode",
-				});
-			}
-
-			default: {
-				const exhaustiveCheck = type satisfies never;
-				throw new Error(`Unsupported type: ${exhaustiveCheck}`);
-			}
-		}
 	}
 
 	async doGenerate(
-		options: Parameters<LanguageModelV1["doGenerate"]>[0],
-	): Promise<Awaited<ReturnType<LanguageModelV1["doGenerate"]>>> {
+		options: Parameters<LanguageModelV2["doGenerate"]>[0],
+	): Promise<Awaited<ReturnType<LanguageModelV2["doGenerate"]>>> {
 		const { args, warnings } = this.getArgs(options);
 
 		const { gateway, safePrompt, ...passthroughOptions } = this.settings;
 
@@ -168,103 +145,336 @@
 			throw new Error("This shouldn't happen");
 		}
 
+		const content: Array<LanguageModelV2Content> = [];
+
+		// text
+		const text = output.response;
+		if (text && text.length > 0) {
+			content.push({ type: "text", text });
+		}
+
+		// tool calls
+		for (const toolCall of output.tool_calls ?? []) {
+			content.push({
+				type: "tool-call" as const,
+				toolCallId: generateId(),
+				toolName: toolCall.name,
+				input: JSON.parse(toolCall.arguments as string),
+			});
+		}
+
 		return {
-			text:
-				typeof output.response === "object" && output.response !== null
-					? JSON.stringify(output.response) // ai-sdk expects a string here
-					: output.response,
-			toolCalls: processToolCalls(output),
+			content,
 			finishReason: "stop", // TODO: mapWorkersAIFinishReason(response.finish_reason)
-			rawCall: { rawPrompt: messages, rawSettings: args },
 			usage: mapWorkersAIUsage(output),
 			warnings,
 		};
 	}
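The `doGenerate` mapping above, isolated as a standalone sketch. The Workers AI output shape is assumed from the code in this diff, and `input` is the parsed tool-argument JSON, as the diff itself does:

```ts
import type { LanguageModelV2Content } from "@ai-sdk/provider";
import { generateId } from "@ai-sdk/provider-utils";

function toContent(output: {
	response?: string;
	tool_calls?: { name: string; arguments: string }[];
}): LanguageModelV2Content[] {
	const content: LanguageModelV2Content[] = [];
	if (output.response) {
		content.push({ type: "text", text: output.response });
	}
	for (const call of output.tool_calls ?? []) {
		content.push({
			type: "tool-call",
			toolCallId: generateId(), // Workers AI returns no call ids
			toolName: call.name,
			input: JSON.parse(call.arguments),
		});
	}
	return content;
}
```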
 	async doStream(
-		options: Parameters<LanguageModelV1["doStream"]>[0],
-	): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>> {
+		options: Parameters<LanguageModelV2["doStream"]>[0],
+	): Promise<Awaited<ReturnType<LanguageModelV2["doStream"]>>> {
 		const { args, warnings } = this.getArgs(options);
-
-		// Extract image from messages if present
 		const { messages, images } = convertToWorkersAIChatMessages(options.prompt);
 
-		// [1] When the latest message is not a tool response, we use the regular generate function
-		// and simulate it as a streamed response in order to satisfy the AI SDK's interface for
-		// doStream...
+		// [1] When tools are in play and the latest message came from the user, fall back to
+		// a non-streaming generate call and replay the result as a simulated stream, which
+		// satisfies the AI SDK's doStream interface.
 		if (args.tools?.length && lastMessageWasUser(messages)) {
 			const response = await this.doGenerate(options);
 
 			if (response instanceof ReadableStream) {
 				throw new Error("This shouldn't happen");
 			}
 
 			return {
-				stream: new ReadableStream({
+				stream: new ReadableStream<LanguageModelV2StreamPart>({
 					async start(controller) {
-						if (response.text) {
-							controller.enqueue({
-								type: "text-delta",
-								textDelta: response.text,
-							});
-						}
-						if (response.toolCalls) {
-							for (const toolCall of response.toolCalls) {
-								controller.enqueue({
-									type: "tool-call",
-									...toolCall,
-								});
+						controller.enqueue({ type: "stream-start", warnings });
+
+						// Replay the generated content as one-shot stream parts.
+						for (const contentPart of response.content) {
+							if (contentPart.type === "text") {
+								const id = generateId();
+								controller.enqueue({ type: "text-start", id });
+								controller.enqueue({
+									type: "text-delta",
+									id,
+									delta: contentPart.text,
+								});
+								controller.enqueue({ type: "text-end", id });
+							} else if (contentPart.type === "tool-call") {
+								controller.enqueue(contentPart);
 							}
 						}
+
 						controller.enqueue({
 							type: "finish",
 							finishReason: "stop",
 							usage: response.usage,
 						});
 						controller.close();
 					},
 				}),
-				rawCall: { rawPrompt: messages, rawSettings: args },
-				warnings,
+				request: { body: args },
+				response: {},
 			};
 		}
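The fallback above follows a general pattern: replay a completed generation as a one-shot V2 stream. A self-contained sketch of that pattern, assuming the v2 stream-part union:

```ts
import type { LanguageModelV2StreamPart } from "@ai-sdk/provider";

function simulateStream(text: string): ReadableStream<LanguageModelV2StreamPart> {
	return new ReadableStream<LanguageModelV2StreamPart>({
		start(controller) {
			controller.enqueue({ type: "stream-start", warnings: [] });
			controller.enqueue({ type: "text-start", id: "0" });
			controller.enqueue({ type: "text-delta", id: "0", delta: text });
			controller.enqueue({ type: "text-end", id: "0" });
			controller.enqueue({
				type: "finish",
				finishReason: "stop",
				usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 },
			});
			controller.close();
		},
	});
}
```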
+ // fallback: simulate streaming with a full generation call if (args.tools?.length && lastMessageWasUser(messages)) { + console.log('Fallback to full generation call for streaming'); const response = await this.doGenerate(options); - if (response instanceof ReadableStream) { - throw new Error("This shouldn't happen"); - } + if (response instanceof ReadableStream) throw new Error('Unexpected stream'); return { - stream: new ReadableStream({ + stream: new ReadableStream({ async start(controller) { - if (response.text) { - controller.enqueue({ - type: "text-delta", - textDelta: response.text, - }); + controller.enqueue({ type: 'stream-start', warnings }); + console.log('Starting fallback stream', args); + console.log('Response from fallback stream:', response); + + if (response.content) { + // Convert content to stream parts + for (const contentPart of response.content) { + if (contentPart.type === 'text') { + controller.enqueue({ + type: 'text-start', + id: generateId(), + }); + controller.enqueue({ + type: 'text-delta', + id: generateId(), + delta: contentPart.text, + }); + controller.enqueue({ + type: 'text-end', + id: generateId(), + }); + } else if (contentPart.type === 'tool-call') { + controller.enqueue(contentPart); + } + } } + + + //@ts-ignore if (response.toolCalls) { + + //@ts-ignore for (const toolCall of response.toolCalls) { controller.enqueue({ - type: "tool-call", - ...toolCall, + type: 'tool-call', + toolCallId: toolCall.id ?? crypto.randomUUID(), + toolName: toolCall.function.name, + input: JSON.parse(toolCall.function.arguments), }); } } + controller.enqueue({ - type: "finish", - finishReason: "stop", - usage: response.usage, + type: 'finish', + finishReason: 'stop', + usage: response.usage ?? { + inputTokens: undefined, + outputTokens: undefined, + totalTokens: undefined, + }, }); + controller.close(); }, }), - rawCall: { rawPrompt: messages, rawSettings: args }, - warnings, + request: { body: args }, + response: {}, }; } - // [2] ...otherwise, we just proceed as normal and stream the response directly from the remote model. - const { gateway, ...passthroughOptions } = this.settings; - - // TODO: support for multiple images - if (images.length !== 0 && images.length !== 1) { - throw new Error("Multiple images are not yet supported as input"); - } - + // real streaming flow from Workers AI const imagePart = images[0]; + console.log('Starting Workers AI stream with args:', args, 'and imagePart:', imagePart); + console.log('Messages being sent:', JSON.stringify(messages, null, 2)); + + const runOptions = { + messages, + max_tokens: args.max_tokens, + stream: true, + temperature: args.temperature, + tools: args.tools, + top_p: args.top_p, + ...(imagePart ? { image: Array.from(imagePart.image) } : {}), + // Don't include response_format for streaming + }; + + console.log('Full run options:', JSON.stringify(runOptions, null, 2)); + const response = await this.config.binding.run( args.model, + runOptions as any, { - messages: messages, - max_tokens: args.max_tokens, - stream: true, - temperature: args.temperature, - tools: args.tools, - top_p: args.top_p, - // Convert Uint8Array to Array of integers for Llama 3.2 Vision model - // TODO: maybe use the base64 string version? - ...(imagePart ? { image: Array.from(imagePart.image) } : {}), - // @ts-expect-error response_format not yet added to types - response_format: args.response_format, + gateway: this.config.gateway ?? this.settings.gateway, }, - { gateway: this.config.gateway ?? 
 		return {
-			stream: getMappedStream(new Response(response)),
-			rawCall: { rawPrompt: messages, rawSettings: args },
-			warnings,
+			stream: response.pipeThrough(
+				new TransformStream<Uint8Array, LanguageModelV2StreamPart>({
+					start(controller) {
+						controller.enqueue({ type: "stream-start", warnings });
+					},
+					transform(chunk, controller) {
+						const text = decoder.decode(chunk, { stream: true });
+
+						// Workers AI streams server-sent events ("data: <json>" lines).
+						// TODO: buffer partial lines that straddle chunk boundaries.
+						for (const line of text.split("\n")) {
+							if (!line.startsWith("data: ")) {
+								continue;
+							}
+							const dataStr = line.slice(6); // Remove the "data: " prefix
+
+							// Skip the end-of-stream marker.
+							if (dataStr === "[DONE]") {
+								continue;
+							}
+
+							try {
+								const data = JSON.parse(dataStr);
+
+								// Surface upstream errors, but still process any partial response.
+								if (data.errors && data.errors.length > 0) {
+									console.error("Workers AI error:", data.errors);
+								}
+
+								if (data.response != null && data.response !== "") {
+									if (!textStarted) {
+										textId = generateId();
+										controller.enqueue({ type: "text-start", id: textId });
+										textStarted = true;
+									}
+									controller.enqueue({
+										type: "text-delta",
+										id: textId as string,
+										delta: data.response,
+									});
+								}
+
+								if (data.usage) {
+									lastUsage = data.usage;
+								}
+							} catch {
+								// Ignore non-JSON or partial SSE payloads.
+							}
+						}
+					},
+					flush(controller) {
+						// Only close the text block if one was opened.
+						if (textStarted && textId) {
+							controller.enqueue({ type: "text-end", id: textId });
+						}
+						controller.enqueue({
+							type: "finish",
+							finishReason: "stop",
+							usage: {
+								inputTokens: lastUsage?.prompt_tokens ?? 0,
+								outputTokens: lastUsage?.completion_tokens ?? 0,
+								totalTokens: lastUsage?.total_tokens ?? 0,
+							},
+						});
+					},
+				}),
+			),
+			request: { body: args },
+			response: {},
 		};
 	}
 }
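What the transform above does to a typical exchange, as a runnable sketch — the SSE frames are illustrative, and the part union is assumed from `@ai-sdk/provider@2.0.0-beta.1`:

```ts
import type { LanguageModelV2StreamPart } from "@ai-sdk/provider";

const frames = [
	'data: {"response":"Hel"}',
	'data: {"response":"lo!"}',
	'data: {"response":"","usage":{"prompt_tokens":9,"completion_tokens":2,"total_tokens":11}}',
	"data: [DONE]",
];

const parts: LanguageModelV2StreamPart[] = [{ type: "stream-start", warnings: [] }];
let usage = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
let started = false;

for (const frame of frames) {
	const data = frame.slice("data: ".length);
	if (data === "[DONE]") break;
	const json = JSON.parse(data);
	if (json.response) {
		if (!started) {
			parts.push({ type: "text-start", id: "0" });
			started = true;
		}
		parts.push({ type: "text-delta", id: "0", delta: json.response });
	}
	if (json.usage) {
		usage = {
			inputTokens: json.usage.prompt_tokens,
			outputTokens: json.usage.completion_tokens,
			totalTokens: json.usage.total_tokens,
		};
	}
}
if (started) parts.push({ type: "text-end", id: "0" });
parts.push({ type: "finish", finishReason: "stop", usage });
```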
diff --git a/packages/workers-ai-provider/src/workersai-image-model.ts b/packages/workers-ai-provider/src/workersai-image-model.ts
index a3e44ce1..5011e4c8 100644
--- a/packages/workers-ai-provider/src/workersai-image-model.ts
+++ b/packages/workers-ai-provider/src/workersai-image-model.ts
@@ -1,10 +1,10 @@
-import type { ImageModelV1, ImageModelV1CallWarning } from "@ai-sdk/provider";
+import type { ImageModelV2, ImageModelV2CallWarning } from "@ai-sdk/provider";
 import type { WorkersAIImageConfig } from "./workersai-image-config";
 import type { WorkersAIImageSettings } from "./workersai-image-settings";
 import type { ImageGenerationModels } from "./workersai-models";
 
-export class WorkersAIImageModel implements ImageModelV1 {
-	readonly specificationVersion = "v1";
+export class WorkersAIImageModel implements ImageModelV2 {
+	readonly specificationVersion = "v2";
 
 	get maxImagesPerCall(): number {
 		return this.settings.maxImagesPerCall ?? 1;
@@ -27,12 +27,12 @@
 		seed,
 		// headers,
 		// abortSignal,
-	}: Parameters<ImageModelV1["doGenerate"]>[0]): Promise<
-		Awaited<ReturnType<ImageModelV1["doGenerate"]>>
+	}: Parameters<ImageModelV2["doGenerate"]>[0]): Promise<
+		Awaited<ReturnType<ImageModelV2["doGenerate"]>>
 	> {
 		const { width, height } = getDimensionsFromSizeString(size);
 
-		const warnings: Array<ImageModelV1CallWarning> = [];
+		const warnings: Array<ImageModelV2CallWarning> = [];
 
 		if (aspectRatio != null) {
 			warnings.push({
diff --git a/packages/workers-ai-provider/src/workersai-models.ts b/packages/workers-ai-provider/src/workersai-models.ts
index 69870f05..acf9daca 100644
--- a/packages/workers-ai-provider/src/workersai-models.ts
+++ b/packages/workers-ai-provider/src/workersai-models.ts
@@ -17,3 +17,5 @@
 export type ImageGenerationModels = value2key<AiModels, BaseAiTextToImage>;
 export type EmbeddingModels = value2key<AiModels, BaseAiTextEmbeddings>;
 type value2key<T, V> = { [K in keyof T]: T[K] extends V ? K : never }[keyof T];
+
+
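A sketch of calling the migrated image model; the exact `ImageModelV2.doGenerate` option set is taken from the v5 beta types as I read them, so treat the field list as an assumption (model id illustrative):

```ts
import { createWorkersAI } from "workers-ai-provider";

export async function makeImage(env: { AI: Ai }) {
	const workersai = createWorkersAI({ binding: env.AI });
	const model = workersai.imageModel("@cf/black-forest-labs/flux-1-schnell");

	const { images, warnings } = await model.doGenerate({
		prompt: "A lighthouse at dusk, oil painting",
		n: 1,
		size: "1024x1024",
		aspectRatio: undefined, // unsupported → surfaces as a warning
		seed: undefined,
		providerOptions: {},
	});
	return { images, warnings };
}
```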
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 1b2de2f4..4fe209a5 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -824,7 +824,7 @@ importers:
         version: 1.12.0
       '@workos-inc/node':
         specifier: ^7.51.0
-        version: 7.51.0(express@5.1.0)
+        version: 7.52.0(express@5.1.0)
       agents:
         specifier: ^0.0.93
         version: 0.0.93(@cloudflare/workers-types@4.20250525.0)(react@19.1.0)
@@ -1574,8 +1574,11 @@ importers:
   packages/workers-ai-provider:
     dependencies:
       '@ai-sdk/provider':
-        specifier: ^1.1.3
-        version: 1.1.3
+        specifier: 2.0.0-beta.1
+        version: 2.0.0-beta.1
+      '@ai-sdk/provider-utils':
+        specifier: 2.2.8
+        version: 2.2.8(zod@3.25.28)
     devDependencies:
       '@cloudflare/workers-types':
         specifier: ^4.20250525.0
@@ -1624,6 +1627,10 @@ packages:
     resolution: {integrity: sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==}
     engines: {node: '>=18'}
 
+  '@ai-sdk/provider@2.0.0-beta.1':
+    resolution: {integrity: sha512-Z8SPncMtS3RsoXITmT7NVwrAq6M44dmw0DoUOYJqNNtCu8iMWuxB8Nxsoqpa0uEEy9R1V1ZThJAXTYgjTUxl3w==}
+    engines: {node: '>=18'}
+
   '@ai-sdk/react@1.2.12':
     resolution: {integrity: sha512-jK1IZZ22evPZoQW3vlkZ7wvjYGYF+tRBKXtrcolduIkQ/m/sOAVcVeVDUDvh1T91xCnWCdUGCPZg2avZ90mv3g==}
     engines: {node: '>=18'}
 
@@ -3227,8 +3234,8 @@ packages:
   '@vitest/utils@3.1.4':
     resolution: {integrity: sha512-yriMuO1cfFhmiGc8ataN51+9ooHRuURdfAZfwFd3usWynjzpLslZdYnRegTv32qdgtJTsj15FoeZe2g15fY1gg==}
 
-  '@workos-inc/node@7.51.0':
-    resolution: {integrity: sha512-08y449PDWdPjziyIxal+Bap2iUpcQ989b9VUg3DiJdv95xC6WCVtaFVAG7FLr8aJmX2DaeCdrBpMR3HTowrxSQ==}
+  '@workos-inc/node@7.52.0':
+    resolution: {integrity: sha512-BTQsXlQQ1i8URg4rn9yBNGp840Je7ZAg/i6VpEQ13pLbHnJpxNobX3gJDoxRiSb+KR0jMP19KZ4X8si7E1xH/g==}
     engines: {node: '>=16'}
 
   '@yarnpkg/lockfile@1.1.0':
@@ -6013,6 +6020,10 @@ snapshots:
     dependencies:
       json-schema: 0.4.0
 
+  '@ai-sdk/provider@2.0.0-beta.1':
+    dependencies:
+      json-schema: 0.4.0
+
   '@ai-sdk/react@1.2.12(react@19.1.0)(zod@3.25.28)':
     dependencies:
       '@ai-sdk/provider-utils': 2.2.8(zod@3.25.28)
@@ -7686,7 +7697,7 @@ snapshots:
     loupe: 3.1.3
     tinyrainbow: 2.0.0
 
-  '@workos-inc/node@7.51.0(express@5.1.0)':
+  '@workos-inc/node@7.52.0(express@5.1.0)':
     dependencies:
       iron-session: 6.3.1(express@5.1.0)
       jose: 5.6.3
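Finally, an end-to-end sketch of the tool-calling path this PR reworks (the simulated-stream fallback plus the V2 tool-call content parts). It assumes the AI SDK v5 beta tool API (`tool()` with `inputSchema`) and zod; tool name and model id are illustrative:

```ts
import { streamText, tool } from "ai";
import { z } from "zod";
import { createWorkersAI } from "workers-ai-provider";

export async function demo(env: { AI: Ai }) {
	const workersai = createWorkersAI({ binding: env.AI });
	const result = streamText({
		model: workersai("@cf/meta/llama-3.1-8b-instruct"),
		prompt: "What is the weather in Lisbon?",
		tools: {
			getWeather: tool({
				description: "Get the current weather for a city",
				inputSchema: z.object({ city: z.string() }),
				execute: async ({ city }) => `Sunny in ${city}`,
			}),
		},
	});
	for await (const chunk of result.textStream) {
		console.log(chunk);
	}
}
```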