diff --git a/src/worker.mjs b/src/worker.mjs index 39d13ba..35621ed 100644 --- a/src/worker.mjs +++ b/src/worker.mjs @@ -1,67 +1,75 @@ import { Buffer } from "node:buffer"; export default { - async fetch (request) { - if (request.method === "OPTIONS") { - return handleOPTIONS(); - } - const errHandler = (err) => { - console.error(err); - return new Response(err.message, fixCors({ status: err.status ?? 500 })); - }; - try { - const auth = request.headers.get("Authorization"); - const apiKey = auth?.split(" ")[1]; - const assert = (success) => { - if (!success) { - throw new HttpError("The specified HTTP method is not allowed for the requested resource", 400); + async fetch(request, env) { + request.__env = env; + + if (request.method === "OPTIONS") { + return handleOPTIONS(); + } + + const errHandler = (err) => { + console.error(err); + return new Response(err.message, fixCors({ status: err.status ?? 500 })); + }; + + if (!isWhiteIp(env, request)) { + return errHandler(new HttpError("IP not allowed", 403)); + } + + try { + const auth = request.headers.get("Authorization"); + const apiKey = auth?.split(" ")[1]; + const assert = (success) => { + if (!success) { + throw new HttpError("The specified HTTP method is not allowed for the requested resource", 400); + } + }; + const { pathname } = new URL(request.url); + switch (true) { + case pathname.endsWith("/chat/completions"): + assert(request.method === "POST"); + return handleCompletions(await request.json(), apiKey, request) + .catch(errHandler); + case pathname.endsWith("/embeddings"): + assert(request.method === "POST"); + return handleEmbeddings(await request.json(), apiKey, request) + .catch(errHandler); + case pathname.endsWith("/models"): + assert(request.method === "GET"); + return handleModels(apiKey, request) + .catch(errHandler); + default: + throw new HttpError("404 Not Found", 404); + } + } catch (err) { + return errHandler(err); } - }; - const { pathname } = new URL(request.url); - switch (true) { - case pathname.endsWith("/chat/completions"): - assert(request.method === "POST"); - return handleCompletions(await request.json(), apiKey) - .catch(errHandler); - case pathname.endsWith("/embeddings"): - assert(request.method === "POST"); - return handleEmbeddings(await request.json(), apiKey) - .catch(errHandler); - case pathname.endsWith("/models"): - assert(request.method === "GET"); - return handleModels(apiKey) - .catch(errHandler); - default: - throw new HttpError("404 Not Found", 404); - } - } catch (err) { - return errHandler(err); } - } }; class HttpError extends Error { - constructor(message, status) { - super(message); - this.name = this.constructor.name; - this.status = status; - } + constructor(message, status) { + super(message); + this.name = this.constructor.name; + this.status = status; + } } const fixCors = ({ headers, status, statusText }) => { - headers = new Headers(headers); - headers.set("Access-Control-Allow-Origin", "*"); - return { headers, status, statusText }; + headers = new Headers(headers); + headers.set("Access-Control-Allow-Origin", "*"); + return { headers, status, statusText }; }; const handleOPTIONS = async () => { - return new Response(null, { - headers: { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "*", - "Access-Control-Allow-Headers": "*", - } - }); + return new Response(null, { + headers: { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "*", + "Access-Control-Allow-Headers": "*", + } + }); }; const BASE_URL = "https://generativelanguage.googleapis.com"; @@ -70,578 +78,599 @@ const API_VERSION = "v1beta"; // https://github.com/google-gemini/generative-ai-js/blob/cf223ff4a1ee5a2d944c53cddb8976136382bee6/src/requests/request.ts#L71 const API_CLIENT = "genai-js/0.21.0"; // npm view @google/generative-ai version const makeHeaders = (apiKey, more) => ({ - "x-goog-api-client": API_CLIENT, - ...(apiKey && { "x-goog-api-key": apiKey }), - ...more + "x-goog-api-client": API_CLIENT, + ...(apiKey && { "x-goog-api-key": apiKey }), + ...more }); -async function handleModels (apiKey) { - const response = await fetch(`${BASE_URL}/${API_VERSION}/models`, { - headers: makeHeaders(apiKey), - }); - let { body } = response; - if (response.ok) { - const { models } = JSON.parse(await response.text()); - body = JSON.stringify({ - object: "list", - data: models.map(({ name }) => ({ - id: name.replace("models/", ""), - object: "model", - created: 0, - owned_by: "", - })), - }, null, " "); - } - return new Response(body, fixCors(response)); +async function handleModels(apiKey, request) { + const response = await fetch(`${BASE_URL}/${API_VERSION}/models`, { + headers: makeHeaders(apiKey), + }); + let { body } = response; + if (response.ok) { + const { models } = JSON.parse(await response.text()); + body = JSON.stringify({ + object: "list", + data: models.map(({ name }) => ({ + id: name.replace("models/", ""), + object: "model", + created: 0, + owned_by: "", + })), + req: handleResponseBody(request), + }, null, " "); + } + return new Response(body, fixCors(response)); } const DEFAULT_EMBEDDINGS_MODEL = "text-embedding-004"; -async function handleEmbeddings (req, apiKey) { - if (typeof req.model !== "string") { - throw new HttpError("model is not specified", 400); - } - let model; - if (req.model.startsWith("models/")) { - model = req.model; - } else { - if (!req.model.startsWith("gemini-")) { - req.model = DEFAULT_EMBEDDINGS_MODEL; - } - model = "models/" + req.model; - } - if (!Array.isArray(req.input)) { - req.input = [ req.input ]; - } - const response = await fetch(`${BASE_URL}/${API_VERSION}/${model}:batchEmbedContents`, { - method: "POST", - headers: makeHeaders(apiKey, { "Content-Type": "application/json" }), - body: JSON.stringify({ - "requests": req.input.map(text => ({ - model, - content: { parts: { text } }, - outputDimensionality: req.dimensions, - })) - }) - }); - let { body } = response; - if (response.ok) { - const { embeddings } = JSON.parse(await response.text()); - body = JSON.stringify({ - object: "list", - data: embeddings.map(({ values }, index) => ({ - object: "embedding", - index, - embedding: values, - })), - model: req.model, - }, null, " "); - } - return new Response(body, fixCors(response)); +async function handleEmbeddings(req, apiKey, request) { + if (typeof req.model !== "string") { + throw new HttpError("model is not specified", 400); + } + let model; + if (req.model.startsWith("models/")) { + model = req.model; + } else { + if (!req.model.startsWith("gemini-")) { + req.model = DEFAULT_EMBEDDINGS_MODEL; + } + model = "models/" + req.model; + } + if (!Array.isArray(req.input)) { + req.input = [req.input]; + } + const response = await fetch(`${BASE_URL}/${API_VERSION}/${model}:batchEmbedContents`, { + method: "POST", + headers: makeHeaders(apiKey, { "Content-Type": "application/json" }), + body: JSON.stringify({ + "requests": req.input.map(text => ({ + model, + content: { parts: { text } }, + outputDimensionality: req.dimensions, + })) + }) + }); + let { body } = response; + if (response.ok) { + const { embeddings } = JSON.parse(await response.text()); + body = JSON.stringify({ + object: "list", + data: embeddings.map(({ values }, index) => ({ + object: "embedding", + index, + embedding: values, + })), + model: req.model, + req: handleResponseBody(request), + }, null, " "); + } + return new Response(body, fixCors(response)); } const DEFAULT_MODEL = "gemini-2.0-flash"; -async function handleCompletions (req, apiKey) { - let model = DEFAULT_MODEL; - switch (true) { - case typeof req.model !== "string": - break; - case req.model.startsWith("models/"): - model = req.model.substring(7); - break; - case req.model.startsWith("gemini-"): - case req.model.startsWith("gemma-"): - case req.model.startsWith("learnlm-"): - model = req.model; - } - let body = await transformRequest(req); - switch (true) { - case model.endsWith(":search"): - model = model.substring(0, model.length - 7); - // eslint-disable-next-line no-fallthrough - case req.model.endsWith("-search-preview"): - body.tools = body.tools || []; - body.tools.push({googleSearch: {}}); - } - const TASK = req.stream ? "streamGenerateContent" : "generateContent"; - let url = `${BASE_URL}/${API_VERSION}/models/${model}:${TASK}`; - if (req.stream) { url += "?alt=sse"; } - const response = await fetch(url, { - method: "POST", - headers: makeHeaders(apiKey, { "Content-Type": "application/json" }), - body: JSON.stringify(body), - }); - - body = response.body; - if (response.ok) { - let id = "chatcmpl-" + generateId(); //"chatcmpl-8pMMaqXMK68B3nyDBrapTDrhkHBQK"; - const shared = {}; - if (req.stream) { - body = response.body - .pipeThrough(new TextDecoderStream()) - .pipeThrough(new TransformStream({ - transform: parseStream, - flush: parseStreamFlush, - buffer: "", - shared, - })) - .pipeThrough(new TransformStream({ - transform: toOpenAiStream, - flush: toOpenAiStreamFlush, - streamIncludeUsage: req.stream_options?.include_usage, - model, id, last: [], - shared, - })) - .pipeThrough(new TextEncoderStream()); - } else { - body = await response.text(); - try { - body = JSON.parse(body); - if (!body.candidates) { - throw new Error("Invalid completion object"); +async function handleCompletions(req, apiKey, request) { + let model = DEFAULT_MODEL; + switch (true) { + case typeof req.model !== "string": + break; + case req.model.startsWith("models/"): + model = req.model.substring(7); + break; + case req.model.startsWith("gemini-"): + case req.model.startsWith("gemma-"): + case req.model.startsWith("learnlm-"): + model = req.model; + } + let body = await transformRequest(req); + switch (true) { + case model.endsWith(":search"): + model = model.substring(0, model.length - 7); + // eslint-disable-next-line no-fallthrough + case req.model.endsWith("-search-preview"): + body.tools = body.tools || []; + body.tools.push({ googleSearch: {} }); + } + const TASK = req.stream ? "streamGenerateContent" : "generateContent"; + let url = `${BASE_URL}/${API_VERSION}/models/${model}:${TASK}`; + if (req.stream) { url += "?alt=sse"; } + const response = await fetch(url, { + method: "POST", + headers: makeHeaders(apiKey, { "Content-Type": "application/json" }), + body: JSON.stringify(body), + }); + + body = response.body; + if (response.ok) { + let id = "chatcmpl-" + generateId(); //"chatcmpl-8pMMaqXMK68B3nyDBrapTDrhkHBQK"; + const shared = {}; + if (req.stream) { + body = response.body + .pipeThrough(new TextDecoderStream()) + .pipeThrough(new TransformStream({ + transform: parseStream, + flush: parseStreamFlush, + buffer: "", + shared, + })) + .pipeThrough(new TransformStream({ + transform: toOpenAiStream, + flush: toOpenAiStreamFlush, + streamIncludeUsage: req.stream_options?.include_usage, + model, id, last: [], + shared, + })) + .pipeThrough(new TextEncoderStream()); + } else { + body = await response.text(); + try { + body = { ...JSON.parse(body), req: handleResponseBody(request) }; + if (!body.candidates) { + throw new Error("Invalid completion object"); + } + } catch (err) { + console.error("Error parsing response:", err); + return new Response(body, fixCors(response)); // output as is + } + body = processCompletionsResponse(body, model, id); } - } catch (err) { - console.error("Error parsing response:", err); - return new Response(body, fixCors(response)); // output as is - } - body = processCompletionsResponse(body, model, id); } - } - return new Response(body, fixCors(response)); + return new Response(body, fixCors(response)); } const adjustProps = (schemaPart) => { - if (typeof schemaPart !== "object" || schemaPart === null) { - return; - } - if (Array.isArray(schemaPart)) { - schemaPart.forEach(adjustProps); - } else { - if (schemaPart.type === "object" && schemaPart.properties && schemaPart.additionalProperties === false) { - delete schemaPart.additionalProperties; - } - Object.values(schemaPart).forEach(adjustProps); - } + if (typeof schemaPart !== "object" || schemaPart === null) { + return; + } + if (Array.isArray(schemaPart)) { + schemaPart.forEach(adjustProps); + } else { + if (schemaPart.type === "object" && schemaPart.properties && schemaPart.additionalProperties === false) { + delete schemaPart.additionalProperties; + } + Object.values(schemaPart).forEach(adjustProps); + } }; const adjustSchema = (schema) => { - const obj = schema[schema.type]; - delete obj.strict; - return adjustProps(schema); + const obj = schema[schema.type]; + delete obj.strict; + return adjustProps(schema); }; const harmCategory = [ - "HARM_CATEGORY_HATE_SPEECH", - "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "HARM_CATEGORY_DANGEROUS_CONTENT", - "HARM_CATEGORY_HARASSMENT", - "HARM_CATEGORY_CIVIC_INTEGRITY", + "HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_CIVIC_INTEGRITY", ]; const safetySettings = harmCategory.map(category => ({ - category, - threshold: "BLOCK_NONE", + category, + threshold: "BLOCK_NONE", })); const fieldsMap = { - frequency_penalty: "frequencyPenalty", - max_completion_tokens: "maxOutputTokens", - max_tokens: "maxOutputTokens", - n: "candidateCount", // not for streaming - presence_penalty: "presencePenalty", - seed: "seed", - stop: "stopSequences", - temperature: "temperature", - top_k: "topK", // non-standard - top_p: "topP", + frequency_penalty: "frequencyPenalty", + max_completion_tokens: "maxOutputTokens", + max_tokens: "maxOutputTokens", + n: "candidateCount", // not for streaming + presence_penalty: "presencePenalty", + seed: "seed", + stop: "stopSequences", + temperature: "temperature", + top_k: "topK", // non-standard + top_p: "topP", + response_modalities: "responseModalities", // ["TEXT","IMAGE"] }; const transformConfig = (req) => { - let cfg = {}; - //if (typeof req.stop === "string") { req.stop = [req.stop]; } // no need - for (let key in req) { - const matchedKey = fieldsMap[key]; - if (matchedKey) { - cfg[matchedKey] = req[key]; - } - } - if (req.response_format) { - switch (req.response_format.type) { - case "json_schema": - adjustSchema(req.response_format); - cfg.responseSchema = req.response_format.json_schema?.schema; - if (cfg.responseSchema && "enum" in cfg.responseSchema) { - cfg.responseMimeType = "text/x.enum"; - break; + let cfg = {}; + //if (typeof req.stop === "string") { req.stop = [req.stop]; } // no need + for (let key in req) { + const matchedKey = fieldsMap[key]; + if (matchedKey) { + cfg[matchedKey] = req[key]; } - // eslint-disable-next-line no-fallthrough - case "json_object": - cfg.responseMimeType = "application/json"; - break; - case "text": - cfg.responseMimeType = "text/plain"; - break; - default: - throw new HttpError("Unsupported response_format.type", 400); - } - } - return cfg; + } + if (req.response_format) { + switch (req.response_format.type) { + case "json_schema": + adjustSchema(req.response_format); + cfg.responseSchema = req.response_format.json_schema?.schema; + if (cfg.responseSchema && "enum" in cfg.responseSchema) { + cfg.responseMimeType = "text/x.enum"; + break; + } + // eslint-disable-next-line no-fallthrough + case "json_object": + cfg.responseMimeType = "application/json"; + break; + case "text": + cfg.responseMimeType = "text/plain"; + break; + default: + throw new HttpError("Unsupported response_format.type", 400); + } + } + + if (!req.response_modalities && req.model.includes('image')) { + cfg['response_modalities'] = ["TEXT", "IMAGE"] + } + return cfg; }; const parseImg = async (url) => { - let mimeType, data; - if (url.startsWith("http://") || url.startsWith("https://")) { - try { - const response = await fetch(url); - if (!response.ok) { - throw new Error(`${response.status} ${response.statusText} (${url})`); - } - mimeType = response.headers.get("content-type"); - data = Buffer.from(await response.arrayBuffer()).toString("base64"); - } catch (err) { - throw new Error("Error fetching image: " + err.toString()); - } - } else { - const match = url.match(/^data:(?.*?)(;base64)?,(?.*)$/); - if (!match) { - throw new HttpError("Invalid image data: " + url, 400); - } - ({ mimeType, data } = match.groups); - } - return { - inlineData: { - mimeType, - data, - }, - }; + let mimeType, data; + if (url.startsWith("http://") || url.startsWith("https://")) { + try { + const response = await fetch(url); + if (!response.ok) { + throw new Error(`${response.status} ${response.statusText} (${url})`); + } + mimeType = response.headers.get("content-type"); + data = Buffer.from(await response.arrayBuffer()).toString("base64"); + } catch (err) { + throw new Error("Error fetching image: " + err.toString()); + } + } else { + const match = url.match(/^data:(?.*?)(;base64)?,(?.*)$/); + if (!match) { + throw new HttpError("Invalid image data: " + url, 400); + } + ({ mimeType, data } = match.groups); + } + return { + inlineData: { + mimeType, + data, + }, + }; }; const transformFnResponse = ({ content, tool_call_id }, parts) => { - if (!parts.calls) { - throw new HttpError("No function calls found in the previous message", 400); - } - let response; - try { - response = JSON.parse(content); - } catch (err) { - console.error("Error parsing function response content:", err); - throw new HttpError("Invalid function response: " + content, 400); - } - if (typeof response !== "object" || response === null || Array.isArray(response)) { - response = { result: response }; - } - if (!tool_call_id) { - throw new HttpError("tool_call_id not specified", 400); - } - const { i, name } = parts.calls[tool_call_id] ?? {}; - if (!name) { - throw new HttpError("Unknown tool_call_id: " + tool_call_id, 400); - } - if (parts[i]) { - throw new HttpError("Duplicated tool_call_id: " + tool_call_id, 400); - } - parts[i] = { - functionResponse: { - id: tool_call_id.startsWith("call_") ? null : tool_call_id, - name, - response, - } - }; -}; - -const transformFnCalls = ({ tool_calls }) => { - const calls = {}; - const parts = tool_calls.map(({ function: { arguments: argstr, name }, id, type }, i) => { - if (type !== "function") { - throw new HttpError(`Unsupported tool_call type: "${type}"`, 400); + if (!parts.calls) { + throw new HttpError("No function calls found in the previous message", 400); } - let args; + let response; try { - args = JSON.parse(argstr); + response = JSON.parse(content); } catch (err) { - console.error("Error parsing function arguments:", err); - throw new HttpError("Invalid function arguments: " + argstr, 400); + console.error("Error parsing function response content:", err); + throw new HttpError("Invalid function response: " + content, 400); } - calls[id] = {i, name}; - return { - functionCall: { - id: id.startsWith("call_") ? null : id, - name, - args, - } + if (typeof response !== "object" || response === null || Array.isArray(response)) { + response = { result: response }; + } + if (!tool_call_id) { + throw new HttpError("tool_call_id not specified", 400); + } + const { i, name } = parts.calls[tool_call_id] ?? {}; + if (!name) { + throw new HttpError("Unknown tool_call_id: " + tool_call_id, 400); + } + if (parts[i]) { + throw new HttpError("Duplicated tool_call_id: " + tool_call_id, 400); + } + parts[i] = { + functionResponse: { + id: tool_call_id.startsWith("call_") ? null : tool_call_id, + name, + response, + } }; - }); - parts.calls = calls; - return parts; +}; + +const transformFnCalls = ({ tool_calls }) => { + const calls = {}; + const parts = tool_calls.map(({ function: { arguments: argstr, name }, id, type }, i) => { + if (type !== "function") { + throw new HttpError(`Unsupported tool_call type: "${type}"`, 400); + } + let args; + try { + args = JSON.parse(argstr); + } catch (err) { + console.error("Error parsing function arguments:", err); + throw new HttpError("Invalid function arguments: " + argstr, 400); + } + calls[id] = { i, name }; + return { + functionCall: { + id: id.startsWith("call_") ? null : id, + name, + args, + } + }; + }); + parts.calls = calls; + return parts; }; const transformMsg = async ({ content }) => { - const parts = []; - if (!Array.isArray(content)) { - // system, user: string - // assistant: string or null (Required unless tool_calls is specified.) - parts.push({ text: content }); + const parts = []; + if (!Array.isArray(content)) { + // system, user: string + // assistant: string or null (Required unless tool_calls is specified.) + parts.push({ text: content }); + return parts; + } + // user: + // An array of content parts with a defined type. + // Supported options differ based on the model being used to generate the response. + // Can contain text, image, or audio inputs. + for (const item of content) { + switch (item.type) { + case "text": + parts.push({ text: item.text }); + break; + case "image_url": + parts.push(await parseImg(item.image_url.url)); + break; + case "input_audio": + parts.push({ + inlineData: { + mimeType: "audio/" + item.input_audio.format, + data: item.input_audio.data, + } + }); + break; + default: + throw new HttpError(`Unknown "content" item type: "${item.type}"`, 400); + } + } + if (content.every(item => item.type === "image_url")) { + parts.push({ text: "" }); // to avoid "Unable to submit request because it must have a text parameter" + } return parts; - } - // user: - // An array of content parts with a defined type. - // Supported options differ based on the model being used to generate the response. - // Can contain text, image, or audio inputs. - for (const item of content) { - switch (item.type) { - case "text": - parts.push({ text: item.text }); - break; - case "image_url": - parts.push(await parseImg(item.image_url.url)); - break; - case "input_audio": - parts.push({ - inlineData: { - mimeType: "audio/" + item.input_audio.format, - data: item.input_audio.data, - } - }); - break; - default: - throw new HttpError(`Unknown "content" item type: "${item.type}"`, 400); - } - } - if (content.every(item => item.type === "image_url")) { - parts.push({ text: "" }); // to avoid "Unable to submit request because it must have a text parameter" - } - return parts; }; const transformMessages = async (messages) => { - if (!messages) { return; } - const contents = []; - let system_instruction; - for (const item of messages) { - switch (item.role) { - case "system": - system_instruction = { parts: await transformMsg(item) }; - continue; - case "tool": - // eslint-disable-next-line no-case-declarations - let { role, parts } = contents[contents.length - 1] ?? {}; - if (role !== "function") { - const calls = parts?.calls; - parts = []; parts.calls = calls; - contents.push({ - role: "function", // ignored - parts - }); + if (!messages) { return; } + const contents = []; + let system_instruction; + for (const item of messages) { + switch (item.role) { + case "system": + system_instruction = { parts: await transformMsg(item) }; + continue; + case "tool": + // eslint-disable-next-line no-case-declarations + let { role, parts } = contents[contents.length - 1] ?? {}; + if (role !== "function") { + const calls = parts?.calls; + parts = []; parts.calls = calls; + contents.push({ + role: "function", // ignored + parts + }); + } + transformFnResponse(item, parts); + continue; + case "assistant": + item.role = "model"; + break; + case "user": + break; + default: + throw new HttpError(`Unknown message role: "${item.role}"`, 400); } - transformFnResponse(item, parts); - continue; - case "assistant": - item.role = "model"; - break; - case "user": - break; - default: - throw new HttpError(`Unknown message role: "${item.role}"`, 400); - } - contents.push({ - role: item.role, - parts: item.tool_calls ? transformFnCalls(item) : await transformMsg(item) - }); - } - if (system_instruction) { - if (!contents[0]?.parts.some(part => part.text)) { - contents.unshift({ role: "user", parts: { text: " " } }); - } - } - //console.info(JSON.stringify(contents, 2)); - return { system_instruction, contents }; + contents.push({ + role: item.role, + parts: item.tool_calls ? transformFnCalls(item) : await transformMsg(item) + }); + } + if (system_instruction) { + if (!contents[0]?.parts.some(part => part.text)) { + contents.unshift({ role: "user", parts: { text: " " } }); + } + } + //console.info(JSON.stringify(contents, 2)); + return { system_instruction, contents }; }; const transformTools = (req) => { - let tools, tool_config; - if (req.tools) { - const funcs = req.tools.filter(tool => tool.type === "function"); - funcs.forEach(adjustSchema); - tools = [{ function_declarations: funcs.map(schema => schema.function) }]; - } - if (req.tool_choice) { - const allowed_function_names = req.tool_choice?.type === "function" ? [ req.tool_choice?.function?.name ] : undefined; - if (allowed_function_names || typeof req.tool_choice === "string") { - tool_config = { - function_calling_config: { - mode: allowed_function_names ? "ANY" : req.tool_choice.toUpperCase(), - allowed_function_names + let tools, tool_config; + if (req.tools) { + const funcs = req.tools.filter(tool => tool.type === "function"); + funcs.forEach(adjustSchema); + tools = [{ function_declarations: funcs.map(schema => schema.function) }]; + } + if (req.tool_choice) { + const allowed_function_names = req.tool_choice?.type === "function" ? [req.tool_choice?.function?.name] : undefined; + if (allowed_function_names || typeof req.tool_choice === "string") { + tool_config = { + function_calling_config: { + mode: allowed_function_names ? "ANY" : req.tool_choice.toUpperCase(), + allowed_function_names + } + }; } - }; } - } - return { tools, tool_config }; + return { tools, tool_config }; }; const transformRequest = async (req) => ({ - ...await transformMessages(req.messages), - safetySettings, - generationConfig: transformConfig(req), - ...transformTools(req), + ...await transformMessages(req.messages), + safetySettings, + generationConfig: transformConfig(req), + ...transformTools(req), }); const generateId = () => { - const characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; - const randomChar = () => characters[Math.floor(Math.random() * characters.length)]; - return Array.from({ length: 29 }, randomChar).join(""); + const characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + const randomChar = () => characters[Math.floor(Math.random() * characters.length)]; + return Array.from({ length: 29 }, randomChar).join(""); }; const reasonsMap = { //https://ai.google.dev/api/rest/v1/GenerateContentResponse#finishreason - //"FINISH_REASON_UNSPECIFIED": // Default value. This value is unused. - "STOP": "stop", - "MAX_TOKENS": "length", - "SAFETY": "content_filter", - "RECITATION": "content_filter", - //"OTHER": "OTHER", + //"FINISH_REASON_UNSPECIFIED": // Default value. This value is unused. + "STOP": "stop", + "MAX_TOKENS": "length", + "SAFETY": "content_filter", + "RECITATION": "content_filter", + //"OTHER": "OTHER", }; const SEP = "\n\n|>"; const transformCandidates = (key, cand) => { - const message = { role: "assistant", content: [] }; - for (const part of cand.content?.parts ?? []) { - if (part.functionCall) { - const fc = part.functionCall; - message.tool_calls = message.tool_calls ?? []; - message.tool_calls.push({ - id: fc.id ?? "call_" + generateId(), - type: "function", - function: { - name: fc.name, - arguments: JSON.stringify(fc.args), + const message = { role: "assistant", content: [] }; + for (const part of cand.content?.parts ?? []) { + if (part.functionCall) { + const fc = part.functionCall; + message.tool_calls = message.tool_calls ?? []; + message.tool_calls.push({ + id: fc.id ?? "call_" + generateId(), + type: "function", + function: { + name: fc.name, + arguments: JSON.stringify(fc.args), + } + }); + } else if (part.inlineData) { + message.inlineData = part.inlineData; + } else { + message.content.push(part.text); } - }); - } else { - message.content.push(part.text); - } - } - message.content = message.content.join(SEP) || null; - return { - index: cand.index || 0, // 0-index is absent in new -002 models response - [key]: message, - logprobs: null, - finish_reason: message.tool_calls ? "tool_calls" : reasonsMap[cand.finishReason] || cand.finishReason, - //original_finish_reason: cand.finishReason, - }; + } + message.content = message.content.join(SEP) || null; + return { + index: cand.index || 0, // 0-index is absent in new -002 models response + [key]: message, + logprobs: null, + finish_reason: message.tool_calls ? "tool_calls" : reasonsMap[cand.finishReason] || cand.finishReason, + //original_finish_reason: cand.finishReason, + }; }; const transformCandidatesMessage = transformCandidates.bind(null, "message"); const transformCandidatesDelta = transformCandidates.bind(null, "delta"); const transformUsage = (data) => ({ - completion_tokens: data.candidatesTokenCount, - prompt_tokens: data.promptTokenCount, - total_tokens: data.totalTokenCount + completion_tokens: data.candidatesTokenCount, + prompt_tokens: data.promptTokenCount, + total_tokens: data.totalTokenCount }); const checkPromptBlock = (choices, promptFeedback, key) => { - if (choices.length) { return; } - if (promptFeedback?.blockReason) { - console.log("Prompt block reason:", promptFeedback.blockReason); - if (promptFeedback.blockReason === "SAFETY") { - promptFeedback.safetyRatings - .filter(r => r.blocked) - .forEach(r => console.log(r)); - } - choices.push({ - index: 0, - [key]: null, - finish_reason: "content_filter", - //original_finish_reason: data.promptFeedback.blockReason, - }); - } - return true; + if (choices.length) { return; } + if (promptFeedback?.blockReason) { + console.log("Prompt block reason:", promptFeedback.blockReason); + if (promptFeedback.blockReason === "SAFETY") { + promptFeedback.safetyRatings + .filter(r => r.blocked) + .forEach(r => console.log(r)); + } + choices.push({ + index: 0, + [key]: null, + finish_reason: "content_filter", + //original_finish_reason: data.promptFeedback.blockReason, + }); + } + return true; }; const processCompletionsResponse = (data, model, id) => { - const obj = { - id, - choices: data.candidates.map(transformCandidatesMessage), - created: Math.floor(Date.now()/1000), - model: data.modelVersion ?? model, - //system_fingerprint: "fp_69829325d0", - object: "chat.completion", - usage: data.usageMetadata && transformUsage(data.usageMetadata), - }; - if (obj.choices.length === 0 ) { - checkPromptBlock(obj.choices, data.promptFeedback, "message"); - } - return JSON.stringify(obj); + const obj = { + id, + choices: data.candidates.map(transformCandidatesMessage), + created: Math.floor(Date.now() / 1000), + model: data.modelVersion ?? model, + //system_fingerprint: "fp_69829325d0", + object: "chat.completion", + usage: data.usageMetadata && transformUsage(data.usageMetadata), + req: data.req, + }; + if (obj.choices.length === 0) { + checkPromptBlock(obj.choices, data.promptFeedback, "message"); + } + return JSON.stringify(obj); }; const responseLineRE = /^data: (.*)(?:\n\n|\r\r|\r\n\r\n)/; -function parseStream (chunk, controller) { - this.buffer += chunk; - do { - const match = this.buffer.match(responseLineRE); - if (!match) { break; } - controller.enqueue(match[1]); - this.buffer = this.buffer.substring(match[0].length); - } while (true); // eslint-disable-line no-constant-condition +function parseStream(chunk, controller) { + this.buffer += chunk; + do { + const match = this.buffer.match(responseLineRE); + if (!match) { break; } + controller.enqueue(match[1]); + this.buffer = this.buffer.substring(match[0].length); + } while (true); // eslint-disable-line no-constant-condition } -function parseStreamFlush (controller) { - if (this.buffer) { - console.error("Invalid data:", this.buffer); - controller.enqueue(this.buffer); - this.shared.is_buffers_rest = true; - } +function parseStreamFlush(controller) { + if (this.buffer) { + console.error("Invalid data:", this.buffer); + controller.enqueue(this.buffer); + this.shared.is_buffers_rest = true; + } } const delimiter = "\n\n"; const sseline = (obj) => { - obj.created = Math.floor(Date.now()/1000); - return "data: " + JSON.stringify(obj) + delimiter; + obj.created = Math.floor(Date.now() / 1000); + return "data: " + JSON.stringify(obj) + delimiter; }; -function toOpenAiStream (line, controller) { - let data; - try { - data = JSON.parse(line); - if (!data.candidates) { - throw new Error("Invalid completion chunk object"); - } - } catch (err) { - console.error("Error parsing response:", err); - if (!this.shared.is_buffers_rest) { line =+ delimiter; } - controller.enqueue(line); // output as is - return; - } - const obj = { - id: this.id, - choices: data.candidates.map(transformCandidatesDelta), - //created: Math.floor(Date.now()/1000), - model: data.modelVersion ?? this.model, - //system_fingerprint: "fp_69829325d0", - object: "chat.completion.chunk", - usage: data.usageMetadata && this.streamIncludeUsage ? null : undefined, - }; - if (checkPromptBlock(obj.choices, data.promptFeedback, "delta")) { - controller.enqueue(sseline(obj)); - return; - } - console.assert(data.candidates.length === 1, "Unexpected candidates count: %d", data.candidates.length); - const cand = obj.choices[0]; - cand.index = cand.index || 0; // absent in new -002 models response - const finish_reason = cand.finish_reason; - cand.finish_reason = null; - if (!this.last[cand.index]) { // first - controller.enqueue(sseline({ - ...obj, - choices: [{ ...cand, tool_calls: undefined, delta: { role: "assistant", content: "" } }], - })); - } - delete cand.delta.role; - if ("content" in cand.delta) { // prevent empty data (e.g. when MAX_TOKENS) - controller.enqueue(sseline(obj)); - } - cand.finish_reason = finish_reason; - if (data.usageMetadata && this.streamIncludeUsage) { - obj.usage = transformUsage(data.usageMetadata); - } - cand.delta = {}; - this.last[cand.index] = obj; +function toOpenAiStream(line, controller) { + let data; + try { + data = JSON.parse(line); + if (!data.candidates) { + throw new Error("Invalid completion chunk object"); + } + } catch (err) { + console.error("Error parsing response:", err); + if (!this.shared.is_buffers_rest) { line = + delimiter; } + controller.enqueue(line); // output as is + return; + } + const obj = { + id: this.id, + choices: data.candidates.map(transformCandidatesDelta), + //created: Math.floor(Date.now()/1000), + model: data.modelVersion ?? this.model, + //system_fingerprint: "fp_69829325d0", + object: "chat.completion.chunk", + usage: data.usageMetadata && this.streamIncludeUsage ? null : undefined, + }; + if (checkPromptBlock(obj.choices, data.promptFeedback, "delta")) { + controller.enqueue(sseline(obj)); + return; + } + console.assert(data.candidates.length === 1, "Unexpected candidates count: %d", data.candidates.length); + const cand = obj.choices[0]; + cand.index = cand.index || 0; // absent in new -002 models response + const finish_reason = cand.finish_reason; + cand.finish_reason = null; + if (!this.last[cand.index]) { // first + controller.enqueue(sseline({ + ...obj, + choices: [{ ...cand, tool_calls: undefined, delta: { role: "assistant", content: "" } }], + })); + } + delete cand.delta.role; + if ("content" in cand.delta) { // prevent empty data (e.g. when MAX_TOKENS) + controller.enqueue(sseline(obj)); + } + cand.finish_reason = finish_reason; + if (data.usageMetadata && this.streamIncludeUsage) { + obj.usage = transformUsage(data.usageMetadata); + } + cand.delta = {}; + this.last[cand.index] = obj; } -function toOpenAiStreamFlush (controller) { - if (this.last.length > 0) { - for (const obj of this.last) { - controller.enqueue(sseline(obj)); +function toOpenAiStreamFlush(controller) { + if (this.last.length > 0) { + for (const obj of this.last) { + controller.enqueue(sseline(obj)); + } + controller.enqueue("data: [DONE]" + delimiter); } - controller.enqueue("data: [DONE]" + delimiter); - } +} +function handleResponseBody(req) { + const ip = req.headers.get("x-forwarded-for") || req.headers.get("x-real-ip") || req.ip; + return { ip, env: req.__env } +} +function isWhiteIp(env, request) { + const WHITE_IP_LIST = env.WHITE_IP_LIST?.split(",") ?? []; + if (WHITE_IP_LIST.length === 0) { return true; } + + const { ip } = handleResponseBody(request); + return WHITE_IP_LIST.includes(ip); }