-
Notifications
You must be signed in to change notification settings - Fork 780
feat: Create frontend tools framework and integrate to backend ai system #6609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
f3c3410
c8de49c
a9e3fed
a2ba9f3
3d76bff
9394987
0024d42
b9337dd
2da93bc
f0bf09d
be02a6d
5743c94
7c131ff
fd89147
10567e2
f7315e0
0c90284
879298b
8f520dd
d7836e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| /* Copyright 2024 Marimo. All rights reserved. */ | ||
|
|
||
| import type { z } from "zod"; | ||
| import type { AnyZodObject } from "./registry"; | ||
|
|
||
| /** | ||
| * Minimal base class for frontend tools. | ||
| * | ||
| * Structural typing ensures instances are compatible with the Tool<TIn, TOut> | ||
| * interface used by the registry, without importing it here. | ||
| */ | ||
| export abstract class BaseTool< | ||
| TIn extends AnyZodObject, | ||
| TOut extends AnyZodObject, | ||
| > { | ||
| public readonly name: string; | ||
|
||
| public readonly description: string; | ||
| public readonly schema: TIn; | ||
| public readonly outputSchema: TOut; | ||
| public readonly mode: ("manual" | "ask")[]; | ||
|
|
||
| /** | ||
| * Handler exposed to the registry. Calls the subclass implementation. | ||
| */ | ||
| public readonly handler: ( | ||
| args: z.infer<TIn>, | ||
| ) => z.infer<TOut> | Promise<z.infer<TOut>>; | ||
|
|
||
| constructor(options: { | ||
| name: string; | ||
| description: string; | ||
| schema: TIn; | ||
| mode: ("manual" | "ask")[]; | ||
| outputSchema: TOut; | ||
| }) { | ||
| this.name = options.name; | ||
| this.description = options.description; | ||
| this.schema = options.schema; | ||
| this.mode = options.mode; | ||
| this.outputSchema = options.outputSchema; | ||
| this.handler = (args) => Promise.resolve(this.handle(args)); | ||
| } | ||
|
|
||
| /** Implement tool logic in subclasses */ | ||
| protected abstract handle( | ||
| args: z.infer<TIn>, | ||
| ): z.infer<TOut> | Promise<z.infer<TOut>>; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| /* Copyright 2024 Marimo. All rights reserved. */ | ||
|
|
||
| import { type ZodObject, z } from "zod"; | ||
| import type { BaseTool } from "./base"; | ||
| import { testFrontendTool } from "./sample-tool"; | ||
|
|
||
| export type AnyZodObject = ZodObject<z.ZodRawShape>; | ||
|
|
||
| interface StoredTool { | ||
| /** Generic type for to avoid type errors */ | ||
| name: string; | ||
| description: string; | ||
| schema: AnyZodObject; | ||
| outputSchema: AnyZodObject; | ||
| mode: CopilotMode[]; | ||
| handler: (args: unknown) => Promise<unknown>; | ||
| } | ||
|
|
||
| /** should be the same as marimo/_config/config.py > CopilotMode */ | ||
| export type CopilotMode = "manual" | "ask"; | ||
|
|
||
| export interface FrontendToolDefinition { | ||
|
||
| /** should be the same as marimo/_server/ai/tools/types.py > ToolDefinition */ | ||
| name: string; | ||
| description: string; | ||
| parameters: Record<string, unknown>; | ||
| source: "frontend"; | ||
| mode: CopilotMode[]; | ||
| } | ||
|
|
||
| export class FrontendToolRegistry { | ||
| /** All registered tools */ | ||
| private tools = new Map<string, StoredTool>(); | ||
|
|
||
| registerAll<TIn extends AnyZodObject, TOut extends AnyZodObject>( | ||
|
||
| tools: BaseTool<TIn, TOut>[], | ||
| ) { | ||
| tools.forEach((tool) => { | ||
| this.register(tool); | ||
| }); | ||
| } | ||
|
|
||
| private register<TIn extends AnyZodObject, TOut extends AnyZodObject>( | ||
| tool: BaseTool<TIn, TOut>, | ||
| ) { | ||
| // Make type generic to avoid type errors | ||
| // Let invoke() handle runtime type checking | ||
| const stored: StoredTool = { | ||
| name: tool.name, | ||
| description: tool.description, | ||
| schema: tool.schema, | ||
| outputSchema: tool.outputSchema, | ||
| mode: tool.mode, | ||
| handler: tool.handler as (args: unknown) => Promise<unknown>, | ||
| }; | ||
| this.tools.set(tool.name, stored); | ||
| } | ||
|
|
||
| has(toolName: string) { | ||
| return this.tools.has(toolName); | ||
| } | ||
|
|
||
| private getTool(toolName: string): StoredTool { | ||
|
||
| const tool = this.tools.get(toolName); | ||
| if (!tool) { | ||
| throw new Error(`Tool ${toolName} not found`); | ||
| } | ||
| return tool; | ||
| } | ||
|
|
||
| async invoke<TName extends string>( | ||
| toolName: TName, | ||
| rawArgs: unknown, | ||
| ): Promise<unknown> { | ||
| const tool = this.getTool(toolName); | ||
| const handler = tool.handler; | ||
| const inputSchema = tool.schema; | ||
| const outputSchema = tool.outputSchema; | ||
|
|
||
| try { | ||
| // Parse input args | ||
| const inputResponse = await inputSchema.safeParseAsync(rawArgs); | ||
| if (inputResponse.error) { | ||
| const strError = z.prettifyError(inputResponse.error); | ||
| throw new Error(`Tool ${toolName} returned invalid input: ${strError}`); | ||
| } | ||
| const args = inputResponse.data; | ||
|
|
||
| // Call the handler | ||
| const rawOutput = await handler(args); | ||
|
|
||
| // Parse output | ||
| const response = await outputSchema.safeParseAsync(rawOutput); | ||
| if (response.error) { | ||
| const strError = z.prettifyError(response.error); | ||
| throw new Error( | ||
| `Tool ${toolName} returned invalid output: ${strError}`, | ||
| ); | ||
| } | ||
| const output = response.data; | ||
| return output; | ||
| } catch (error) { | ||
| return { | ||
| status: "error", | ||
| code: "TOOL_ERROR", | ||
| message: error instanceof Error ? error.message : String(error), | ||
| suggestedFix: "Try again with valid arguments.", | ||
| meta: { | ||
| args: rawArgs, | ||
| }, | ||
| }; | ||
| } | ||
| } | ||
|
|
||
| getToolSchemas(): FrontendToolDefinition[] { | ||
| return [...this.tools.values()].map((tool) => ({ | ||
| name: tool.name, | ||
| description: tool.description, | ||
| parameters: z.toJSONSchema(tool.schema), | ||
| source: "frontend", | ||
| mode: tool.mode, | ||
| })); | ||
| } | ||
| } | ||
|
|
||
| export const FRONTEND_TOOL_REGISTRY = new FrontendToolRegistry(); | ||
|
|
||
| /* Register all the frontend tools */ | ||
| FRONTEND_TOOL_REGISTRY.registerAll([ | ||
|
||
| testFrontendTool, | ||
| // ADD MORE TOOLS HERE | ||
| ]); | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| /* Copyright 2024 Marimo. All rights reserved. */ | ||
|
|
||
| import { z } from "zod"; | ||
| import { BaseTool } from "./base"; | ||
|
|
||
| const schema = z.object({ name: z.string() }); | ||
| const outputSchema = z.object({ message: z.string() }); | ||
|
|
||
| /** A sample frontend tool that returns "hello world" */ | ||
| export class TestFrontendTool extends BaseTool< | ||
| typeof schema, | ||
| typeof outputSchema | ||
| > { | ||
| constructor() { | ||
| super({ | ||
| name: "test_frontend_tool", | ||
| description: | ||
| "A test frontend tool that returns hi with the name passed in", | ||
| schema, | ||
| outputSchema, | ||
| mode: ["ask"], | ||
| }); | ||
| } | ||
|
|
||
| protected async handle({ name }: z.infer<typeof schema>) { | ||
| return { message: `Hello: ${name}` }; | ||
| } | ||
| } | ||
|
|
||
| export const testFrontendTool = new TestFrontendTool(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we validate the mode here? or not needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think its needed since it's already hard typed. I'm more worried about the CopilotMode type in the frontend going out of sync with the backend.