|
| 1 | +/* Copyright 2024 Marimo. All rights reserved. */ |
| 2 | + |
| 3 | +import { z } from "zod"; |
| 4 | +import { notebookAtom } from "@/core/cells/cells"; |
| 5 | +import type { CellId } from "@/core/cells/ids"; |
| 6 | +import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils"; |
| 7 | +import type { JotaiStore } from "@/core/state/jotai"; |
| 8 | +import { stagedAICellsAtom } from "../staged-cells"; |
| 9 | +import { |
| 10 | + type AiTool, |
| 11 | + ToolExecutionError, |
| 12 | + type ToolOutputBase, |
| 13 | + toolOutputBaseSchema, |
| 14 | +} from "./base"; |
| 15 | +import type { CopilotMode } from "./registry"; |
| 16 | + |
| 17 | +const description = ` |
| 18 | +Perform editing operations on the current notebook. |
| 19 | +Call this tool multiple times to perform multiple edits. |
| 20 | +
|
| 21 | +Args: |
| 22 | +- edit (object): The editing operation to perform. Must be one of: |
| 23 | + - update_cell: Update the code of an existing cell. |
| 24 | + - add_cell: Add a new cell to the notebook. |
| 25 | + - delete_cell: Delete an existing cell. |
| 26 | +
|
| 27 | +Returns: |
| 28 | +- A result object containing standard tool metadata. |
| 29 | +`; |
| 30 | + |
| 31 | +const editNotebookSchema = z.object({ |
| 32 | + edit: z.discriminatedUnion("type", [ |
| 33 | + z.object({ |
| 34 | + type: z.literal("update_cell"), |
| 35 | + cellId: z.string() as unknown as z.ZodType<CellId>, |
| 36 | + code: z.string(), |
| 37 | + }), |
| 38 | + z.object({ |
| 39 | + type: z.literal("add_cell"), |
| 40 | + cellId: z.string() as unknown as z.ZodType<CellId>, |
| 41 | + code: z.string(), |
| 42 | + language: z.enum(["python", "sql", "markdown"]).optional(), |
| 43 | + }), |
| 44 | + z.object({ |
| 45 | + type: z.literal("delete_cell"), |
| 46 | + cellId: z.string() as unknown as z.ZodType<CellId>, |
| 47 | + }), |
| 48 | + ]), |
| 49 | +}); |
| 50 | + |
| 51 | +type EditNotebookInput = z.infer<typeof editNotebookSchema>; |
| 52 | +type EditOperation = EditNotebookInput["edit"]; |
| 53 | +export type EditType = EditOperation["type"]; |
| 54 | + |
| 55 | +export class EditNotebookTool |
| 56 | + implements AiTool<EditNotebookInput, ToolOutputBase> |
| 57 | +{ |
| 58 | + private readonly store: JotaiStore; |
| 59 | + readonly name = "edit_notebook_tool"; |
| 60 | + readonly description = description; |
| 61 | + readonly schema = editNotebookSchema; |
| 62 | + readonly outputSchema = toolOutputBaseSchema; |
| 63 | + readonly mode: CopilotMode[] = ["agent"]; |
| 64 | + |
| 65 | + constructor(store: JotaiStore) { |
| 66 | + this.store = store; |
| 67 | + } |
| 68 | + |
| 69 | + handler = async ({ edit }: EditNotebookInput): Promise<ToolOutputBase> => { |
| 70 | + switch (edit.type) { |
| 71 | + case "update_cell": { |
| 72 | + const { cellId, code } = edit; |
| 73 | + |
| 74 | + const notebook = this.store.get(notebookAtom); |
| 75 | + const cellIds = notebook.cellIds; |
| 76 | + if (!cellIds.getColumns().some((column) => column.idSet.has(cellId))) { |
| 77 | + throw new ToolExecutionError( |
| 78 | + "Cell not found", |
| 79 | + "CELL_NOT_FOUND", |
| 80 | + false, |
| 81 | + "Check which cells exist in the notebook", |
| 82 | + ); |
| 83 | + } |
| 84 | + |
| 85 | + const cellHandles = notebook.cellHandles; |
| 86 | + const cellHandle = cellHandles[cellId].current; |
| 87 | + if (!cellHandle?.editorView) { |
| 88 | + throw new ToolExecutionError( |
| 89 | + "Cell editor not found", |
| 90 | + "CELL_EDITOR_NOT_FOUND", |
| 91 | + false, |
| 92 | + "Internal error, ask the user to report this error", |
| 93 | + ); |
| 94 | + } |
| 95 | + |
| 96 | + const currentCellCode = cellHandle.editorView.state.doc.toString(); |
| 97 | + |
| 98 | + const stagedAICells = this.store.get(stagedAICellsAtom); |
| 99 | + const newStagedAICells = new Map([ |
| 100 | + ...stagedAICells, |
| 101 | + [cellId, { type: "update_cell", previousCode: currentCellCode }], |
| 102 | + ]); |
| 103 | + this.store.set(stagedAICellsAtom, newStagedAICells); |
| 104 | + |
| 105 | + updateEditorCodeFromPython(cellHandle.editorView, code); |
| 106 | + |
| 107 | + break; |
| 108 | + } |
| 109 | + case "add_cell": |
| 110 | + // TODO |
| 111 | + break; |
| 112 | + case "delete_cell": |
| 113 | + // TODO |
| 114 | + break; |
| 115 | + } |
| 116 | + return { |
| 117 | + status: "success", |
| 118 | + }; |
| 119 | + }; |
| 120 | +} |
0 commit comments