diff --git a/.changeset/lazy-cougars-sip.md b/.changeset/lazy-cougars-sip.md new file mode 100644 index 000000000..7b41e0aeb --- /dev/null +++ b/.changeset/lazy-cougars-sip.md @@ -0,0 +1,5 @@ +--- +"@google/generative-ai": patch +--- + +Added custom header support to the Server package, matching functionality on the client package diff --git a/src/server/request.test.ts b/src/server/request.test.ts index dd2c15d25..4a0e8299a 100644 --- a/src/server/request.test.ts +++ b/src/server/request.test.ts @@ -20,9 +20,17 @@ import { match, restore, stub } from "sinon"; import * as sinonChai from "sinon-chai"; import * as chaiAsPromised from "chai-as-promised"; import { DEFAULT_API_VERSION, DEFAULT_BASE_URL } from "../requests/request"; -import { FilesRequestUrl, makeServerRequest } from "./request"; +import { + FilesRequestUrl, + ServerRequestUrl, + getHeaders, + makeServerRequest, +} from "./request"; import { RpcTask } from "./constants"; -import { GoogleGenerativeAIFetchError } from "../errors"; +import { + GoogleGenerativeAIFetchError, + GoogleGenerativeAIRequestInputError, +} from "../errors"; use(sinonChai); use(chaiAsPromised); @@ -212,4 +220,24 @@ describe("Files API - request methods", () => { expect(fetchStub).to.be.calledOnce; }); }); + describe("getHeaders", () => { + it("passes custom headers", () => { + const url = new ServerRequestUrl(RpcTask.GET, "key", { + customHeaders: new Headers({ customHeader: "customHeaderValue" }), + }); + + const headers = getHeaders(url); + + expect(headers.get("customHeader")).to.equal("customHeaderValue"); + }); + it("passes custom x-goog-api-client header", () => { + const url = new ServerRequestUrl(RpcTask.GET, "key", { + customHeaders: new Headers({ "x-goog-api-client": "client/version" }), + }); + + expect(() => getHeaders(url)).to.throw( + GoogleGenerativeAIRequestInputError, + ); + }); + }); }); diff --git a/src/server/request.ts b/src/server/request.ts index 464c1d2e2..e16de647d 100644 --- a/src/server/request.ts +++ b/src/server/request.ts @@ -23,6 +23,7 @@ import { } from "../requests/request"; import { RequestOptions, SingleRequestOptions } from "../../types"; import { RpcTask } from "./constants"; +import { GoogleGenerativeAIRequestInputError } from "../errors"; const taskToMethod = { [RpcTask.UPLOAD]: "POST", @@ -91,6 +92,35 @@ export function getHeaders(url: ServerRequestUrl): Headers { const headers = new Headers(); headers.append("x-goog-api-client", getClientHeaders(url.requestOptions)); headers.append("x-goog-api-key", url.apiKey); + + let customHeaders = url.requestOptions?.customHeaders; + if (customHeaders) { + if (!(customHeaders instanceof Headers)) { + try { + customHeaders = new Headers(customHeaders); + } catch (e) { + throw new GoogleGenerativeAIRequestInputError( + `unable to convert customHeaders value ${JSON.stringify( + customHeaders, + )} to Headers: ${e.message}`, + ); + } + } + + for (const [headerName, headerValue] of customHeaders.entries()) { + if (headerName === "x-goog-api-key") { + throw new GoogleGenerativeAIRequestInputError( + `Cannot set reserved header name ${headerName}`, + ); + } else if (headerName === "x-goog-api-client") { + throw new GoogleGenerativeAIRequestInputError( + `Header name ${headerName} can only be set using the apiClient field`, + ); + } + + headers.append(headerName, headerValue); + } + } return headers; }