Skip to content

Commit 5a868a0

Browse files
authored
Fix AI Search rpc binding (#12543)
1 parent 7ea69af commit 5a868a0

File tree

6 files changed

+238
-2
lines changed

6 files changed

+238
-2
lines changed

.changeset/chatty-rivers-punch.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"miniflare": minor
3+
"wrangler": minor
4+
---
5+
6+
Add support for AI Search RPC method

packages/miniflare/src/plugins/ai/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ export const AI_PLUGIN: Plugin<typeof AIOptionsSchema> = {
7070
),
7171
worker: remoteProxyClientWorker(
7272
options.ai.remoteProxyConnectionString,
73-
options.ai.binding
73+
options.ai.binding,
74+
"ai"
7475
),
7576
},
7677
];

packages/miniflare/src/plugins/shared/constants.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ export function objectEntryWorker(
7575

7676
export function remoteProxyClientWorker(
7777
remoteProxyConnectionString: RemoteProxyConnectionString | undefined,
78-
binding: string
78+
binding: string,
79+
bindingType?: string
7980
) {
8081
return {
8182
compatibilityDate: "2025-01-01",
@@ -98,6 +99,14 @@ export function remoteProxyClientWorker(
9899
name: "binding",
99100
text: binding,
100101
},
102+
...(bindingType
103+
? [
104+
{
105+
name: "bindingType",
106+
text: bindingType,
107+
},
108+
]
109+
: []),
101110
],
102111
};
103112
}

packages/miniflare/src/workers/shared/remote-proxy-client.worker.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { makeFetch } from "./remote-bindings-utils";
55
type Env = {
66
remoteProxyConnectionString?: string;
77
binding: string;
8+
bindingType?: string;
89
};
910
export default class Client extends WorkerEntrypoint<Env> {
1011
async fetch(request: Request) {
@@ -20,6 +21,9 @@ export default class Client extends WorkerEntrypoint<Env> {
2021
const url = new URL(env.remoteProxyConnectionString);
2122
url.protocol = "ws:";
2223
url.searchParams.set("MF-Binding", env.binding);
24+
if (env.bindingType) {
25+
url.searchParams.set("MF-Binding-Type", env.bindingType);
26+
}
2327
stub = newWebSocketRpcSession(url.href);
2428
}
2529

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import { describe, it, vi } from "vitest";
2+
3+
/**
4+
* Tests for the AI RPC method wrapping logic in ProxyServerWorker.
5+
*
6+
* The raw AI binding (deployed with raw:true) has a workerd-internal prototype
7+
* that capnweb classifies as "unsupported", causing
8+
* "RPC stub points at a non-serializable type".
9+
*
10+
* The fix uses MF-Binding-Type (threaded from the miniflare AI plugin through
11+
* the remote-proxy-client WebSocket URL) to identify AI bindings, then wraps
12+
* them in a plain object delegating only the allowed RPC methods.
13+
*/
14+
15+
// Mirrors the constant from ProxyServerWorker.ts
16+
const AI_RPC_METHODS = ["aiSearch"] as const;
17+
18+
/**
19+
* Re-implementation of the AI wrapping logic from
20+
* ProxyServerWorker.ts getExposedJSRPCBinding() so we can unit-test it
21+
* without pulling in cloudflare:email / capnweb.
22+
*/
23+
function wrapIfAiBinding(
24+
bindingType: string | null,
25+
targetBinding: object
26+
): unknown {
27+
if (bindingType === "ai") {
28+
const wrapper: Record<string, (...args: unknown[]) => unknown> = {};
29+
for (const method of AI_RPC_METHODS) {
30+
if (
31+
typeof (targetBinding as Record<string, unknown>)[method] === "function"
32+
) {
33+
wrapper[method] = (...args: unknown[]) =>
34+
(targetBinding as Record<string, (...a: unknown[]) => unknown>)[
35+
method
36+
](...args);
37+
}
38+
}
39+
if (Object.keys(wrapper).length > 0) {
40+
return wrapper;
41+
}
42+
}
43+
return targetBinding;
44+
}
45+
46+
describe("ProxyServerWorker AI RPC wrapping", () => {
47+
it("wraps an AI binding into a plain object", ({ expect }) => {
48+
const binding = { aiSearch: vi.fn(), fetch: vi.fn() };
49+
50+
const wrapped = wrapIfAiBinding("ai", binding);
51+
52+
expect(Object.getPrototypeOf(wrapped)).toBe(Object.prototype);
53+
expect(wrapped).not.toBe(binding);
54+
});
55+
56+
it("delegates aiSearch calls to the underlying binding", async ({
57+
expect,
58+
}) => {
59+
const mockAiSearch = vi.fn().mockResolvedValue({ result: "ok" });
60+
const binding = { aiSearch: mockAiSearch };
61+
62+
const wrapped = wrapIfAiBinding("ai", binding) as Record<
63+
string,
64+
(...args: unknown[]) => unknown
65+
>;
66+
67+
const params = { query: "test" };
68+
const result = await wrapped.aiSearch(params);
69+
70+
expect(mockAiSearch).toHaveBeenCalledWith(params);
71+
expect(result).toEqual({ result: "ok" });
72+
});
73+
74+
it("forwards all arguments to the underlying aiSearch method", ({
75+
expect,
76+
}) => {
77+
const mockAiSearch = vi.fn();
78+
const binding = { aiSearch: mockAiSearch };
79+
const wrapped = wrapIfAiBinding("ai", binding) as Record<
80+
string,
81+
(...args: unknown[]) => unknown
82+
>;
83+
84+
wrapped.aiSearch("arg1", "arg2", { nested: true });
85+
86+
expect(mockAiSearch).toHaveBeenCalledWith("arg1", "arg2", {
87+
nested: true,
88+
});
89+
});
90+
91+
it("does not wrap bindings without the ai binding type", ({ expect }) => {
92+
const binding = { aiSearch: vi.fn(), otherMethod: vi.fn() };
93+
94+
const result = wrapIfAiBinding(null, binding);
95+
96+
expect(result).toBe(binding);
97+
});
98+
99+
it("does not wrap a service binding even if it has aiSearch", ({
100+
expect,
101+
}) => {
102+
const binding = { aiSearch: vi.fn(), otherMethod: vi.fn() };
103+
104+
const result = wrapIfAiBinding("service", binding);
105+
106+
expect(result).toBe(binding);
107+
});
108+
109+
it("does not expose non-allowlisted methods from the raw binding", ({
110+
expect,
111+
}) => {
112+
const binding = {
113+
aiSearch: vi.fn(),
114+
fetch: vi.fn(),
115+
someInternalMethod: vi.fn(),
116+
};
117+
118+
const wrapped = wrapIfAiBinding("ai", binding) as Record<string, unknown>;
119+
120+
expect(wrapped).toHaveProperty("aiSearch");
121+
expect(wrapped).not.toHaveProperty("fetch");
122+
expect(wrapped).not.toHaveProperty("someInternalMethod");
123+
});
124+
125+
it("propagates errors thrown by the underlying aiSearch method", async ({
126+
expect,
127+
}) => {
128+
const binding = {
129+
aiSearch: vi.fn().mockRejectedValue(new Error("AI Search failed")),
130+
};
131+
132+
const wrapped = wrapIfAiBinding("ai", binding) as Record<
133+
string,
134+
(...args: unknown[]) => Promise<unknown>
135+
>;
136+
137+
await expect(wrapped.aiSearch({})).rejects.toThrow("AI Search failed");
138+
});
139+
140+
it("propagates RpcTarget-like return values for multi-level RPC", async ({
141+
expect,
142+
}) => {
143+
class MockAccountService {
144+
async list() {
145+
return [{ id: "instance-1" }];
146+
}
147+
get(name: string) {
148+
return new MockInstanceService(name);
149+
}
150+
}
151+
class MockInstanceService {
152+
constructor(public instanceId: string) {}
153+
async search(params: { query: string }) {
154+
return { chunks: [], search_query: params.query };
155+
}
156+
}
157+
158+
const binding = {
159+
aiSearch: vi.fn().mockReturnValue(new MockAccountService()),
160+
};
161+
162+
const wrapped = wrapIfAiBinding("ai", binding) as Record<
163+
string,
164+
(...args: unknown[]) => unknown
165+
>;
166+
167+
const svc = wrapped.aiSearch() as MockAccountService;
168+
expect(await svc.list()).toEqual([{ id: "instance-1" }]);
169+
170+
const inst = svc.get("my-instance");
171+
expect(await inst.search({ query: "test" })).toEqual({
172+
chunks: [],
173+
search_query: "test",
174+
});
175+
});
176+
177+
it("returns binding as-is when type is ai but no RPC methods exist", ({
178+
expect,
179+
}) => {
180+
const binding = { fetch: vi.fn() };
181+
182+
const result = wrapIfAiBinding("ai", binding);
183+
184+
expect(result).toBe(binding);
185+
});
186+
});

packages/wrangler/templates/remoteBindings/ProxyServerWorker.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@ import { EmailMessage } from "cloudflare:email";
33

44
interface Env extends Record<string, unknown> {}
55

6+
/**
7+
* List of RPC methods exposed by the raw AI binding that need proxying
8+
* through a plain-object wrapper. The raw AI binding (deployed with raw:true)
9+
* has a non-standard prototype that capnweb's typeForRpc() doesn't recognise,
10+
* causing "RPC stub points at a non-serializable type". By wrapping only the
11+
* allowed RPC methods in a plain object we give capnweb an Object.prototype
12+
* target it can navigate.
13+
*
14+
* Add new AI RPC method names here as they are introduced.
15+
*/
16+
const AI_RPC_METHODS = ["aiSearch"] as const;
17+
618
class BindingNotFoundError extends Error {
719
constructor(name?: string) {
820
super(`Binding ${name ? `"${name}"` : ""} not found`);
@@ -54,6 +66,11 @@ async function evaluateMediaBinding(
5466
* can't emulate that over an async boundary, we mock it locally and _actually_
5567
* perform the .get() remotely at the first appropriate async point. See
5668
* packages/miniflare/src/workers/dispatch-namespace/dispatch-namespace.worker.ts
69+
* - AI bindings (raw:true / minimal_mode) have a workerd-internal prototype
70+
* that capnweb's typeForRpc() classifies as "unsupported", causing
71+
* "RPC stub points at a non-serializable type". We wrap the binding in a
72+
* plain object that delegates only the allowed RPC methods (AI_RPC_METHODS)
73+
* so capnweb gets an Object.prototype target it can navigate.
5774
*
5875
* getExposedJSRPCBinding() and getExposedFetcher() perform the logic for figuring out
5976
* which binding is being accessed, dependending on the request. Note: Both have logic
@@ -93,6 +110,19 @@ function getExposedJSRPCBinding(request: Request, env: Env) {
93110
};
94111
}
95112

113+
if (url.searchParams.get("MF-Binding-Type") === "ai") {
114+
const wrapper: Record<string, (...args: unknown[]) => unknown> = {};
115+
for (const method of AI_RPC_METHODS) {
116+
if (typeof (targetBinding as any)[method] === "function") {
117+
wrapper[method] = (...args: unknown[]) =>
118+
(targetBinding as any)[method](...args);
119+
}
120+
}
121+
if (Object.keys(wrapper).length > 0) {
122+
return wrapper;
123+
}
124+
}
125+
96126
if (url.searchParams.has("MF-Dispatch-Namespace-Options")) {
97127
const { name, args, options } = JSON.parse(
98128
url.searchParams.get("MF-Dispatch-Namespace-Options")!

0 commit comments

Comments
 (0)