diff --git a/frontend/src/components/chat/chat-panel.tsx b/frontend/src/components/chat/chat-panel.tsx index 6d223f12d6d..e92fc193def 100644 --- a/frontend/src/components/chat/chat-panel.tsx +++ b/frontend/src/components/chat/chat-panel.tsx @@ -36,7 +36,10 @@ import { type ChatId, chatStateAtom, } from "@/core/ai/state"; -import { FRONTEND_TOOL_REGISTRY } from "@/core/ai/tools/registry"; +import { + type CopilotMode, + FRONTEND_TOOL_REGISTRY, +} from "@/core/ai/tools/registry"; import { aiAtom, aiEnabledAtom } from "@/core/config/config"; import { DEFAULT_AI_MODEL } from "@/core/config/config-schema"; import { FeatureFlagged } from "@/core/config/feature-flag"; @@ -301,7 +304,11 @@ const ChatInputFooter: React.FC = memo( const { saveModeChange } = useModelChange(); - const modeOptions = [ + const modeOptions: { + value: CopilotMode; + label: string; + subtitle: string; + }[] = [ { value: "ask", label: "Ask", @@ -313,6 +320,11 @@ const ChatInputFooter: React.FC = memo( label: "Manual", subtitle: "Pure chat, no tool usage", }, + { + value: "agent", + label: "Agent (beta)", + subtitle: "Use AI with access to read and write tools", + }, ]; const isAttachmentSupported = diff --git a/frontend/src/components/editor/ai/ai-completion-editor.tsx b/frontend/src/components/editor/ai/ai-completion-editor.tsx index 58c104fa2e0..a5c5084f422 100644 --- a/frontend/src/components/editor/ai/ai-completion-editor.tsx +++ b/frontend/src/components/editor/ai/ai-completion-editor.tsx @@ -11,14 +11,16 @@ import { customPythonLanguageSupport } from "@/core/codemirror/language/language import "./merge-editor.css"; import { storePrompt } from "@marimo-team/codemirror-ai"; import type { ReactCodeMirrorRef } from "@uiw/react-codemirror"; -import { useAtom } from "jotai"; +import { useAtom, useAtomValue } from "jotai"; import { AIModelDropdown } from "@/components/ai/ai-model-dropdown"; import { Checkbox } from "@/components/ui/checkbox"; import { Label } from "@/components/ui/label"; import { Switch } from "@/components/ui/switch"; import { Tooltip } from "@/components/ui/tooltip"; import { toast } from "@/components/ui/use-toast"; -import { includeOtherCellsAtom } from "@/core/ai/state"; +import { stagedAICellsAtom } from "@/core/ai/staged-cells"; +import { type AiCompletionCell, includeOtherCellsAtom } from "@/core/ai/state"; +import type { CellId } from "@/core/cells/ids"; import { getCodes } from "@/core/codemirror/copilot/getCodes"; import type { LanguageAdapterType } from "@/core/codemirror/language/types"; import { selectAllText } from "@/core/codemirror/utils"; @@ -40,15 +42,14 @@ const Original = CodeMirrorMerge.Original; const Modified = CodeMirrorMerge.Modified; interface Props { + cellId: CellId; + aiCompletionCell: AiCompletionCell | null; className?: string; currentCode: string; currentLanguageAdapter: LanguageAdapterType | undefined; - initialPrompt: string | undefined; onChange: (code: string) => void; declineChange: () => void; acceptChange: (rightHandCode: string) => void; - enabled: boolean; - triggerImmediately?: boolean; runCell: () => void; outputArea?: "above" | "below"; /** @@ -65,15 +66,14 @@ const baseExtensions = [customPythonLanguageSupport(), EditorView.lineWrapping]; * This shows a left/right split with the original and modified code. */ export const AiCompletionEditor: React.FC = ({ + cellId, + aiCompletionCell, className, onChange, - initialPrompt, currentLanguageAdapter, currentCode, declineChange, acceptChange, - enabled, - triggerImmediately, runCell, outputArea, children, @@ -88,6 +88,20 @@ export const AiCompletionEditor: React.FC = ({ const runtimeManager = useRuntimeManager(); + const { + initialPrompt, + triggerImmediately, + cellId: aiCellId, + } = aiCompletionCell ?? {}; + const enabled = aiCellId === cellId; + + const stagedAICells = useAtomValue(stagedAICellsAtom); + const updatedCell = stagedAICells.get(cellId); + let previousCellCode: string | undefined; + if (updatedCell?.type === "update_cell") { + previousCellCode = updatedCell.previousCode; + } + const { completion: untrimmedCompletion, input, @@ -178,6 +192,8 @@ export const AiCompletionEditor: React.FC = ({ const showCompletionBanner = enabled && triggerImmediately && (completion || isLoading); + // Set default output area to below if not specified + outputArea = outputArea ?? "below"; const showInput = enabled && (!triggerImmediately || showInputPrompt); @@ -202,6 +218,35 @@ export const AiCompletionEditor: React.FC = ({ ); + const renderMergeEditor = (originalCode: string, modifiedCode: string) => { + return ( + + + + + ); + }; + + const renderCompletionEditor = () => { + if (completion && enabled) { + return renderMergeEditor(currentCode, completion); + } + // If there is no completion and there is previous cell code, it means there is an AI change to the cell. + // And we want to render the previous cell code as the original + if (!completion && previousCellCode) { + return renderMergeEditor(previousCellCode, currentCode); + } + }; + return (
= ({ )}
{outputArea === "above" && completionBanner} - {completion && enabled && ( - - - - - )} - {(!completion || !enabled) && children} + {renderCompletionEditor()} + {(!completion || !enabled) && !previousCellCode && children} {/* By default, show the completion banner below the code */} - {(outputArea === "below" || !outputArea) && completionBanner} + {outputArea === "below" && completionBanner}
); }; diff --git a/frontend/src/components/editor/ai/completion-handlers.tsx b/frontend/src/components/editor/ai/completion-handlers.tsx index 7b061b7a055..f4bd85fa638 100644 --- a/frontend/src/components/editor/ai/completion-handlers.tsx +++ b/frontend/src/components/editor/ai/completion-handlers.tsx @@ -114,6 +114,7 @@ export const CompletionActionsCellFooter: React.FC<{ export const AcceptCompletionButton: React.FC<{ isLoading: boolean; onAccept: () => void; + text?: string; size?: "xs" | "sm"; buttonStyles?: string; playButtonStyles?: string; @@ -122,6 +123,7 @@ export const AcceptCompletionButton: React.FC<{ }> = ({ isLoading, onAccept, + text = "Accept", size = "sm", buttonStyles, acceptShortcut, @@ -150,7 +152,7 @@ export const AcceptCompletionButton: React.FC<{ onClick={onAccept} className={`${baseClasses} rounded-r-none ${buttonStyles}`} > - Accept + {text} {acceptShortcut && ( )} @@ -178,7 +180,7 @@ export const AcceptCompletionButton: React.FC<{ onClick={onAccept} className={`${baseClasses} rounded px-3 ${buttonStyles}`} > - Accept + {text} {acceptShortcut && ( )} @@ -188,10 +190,17 @@ export const AcceptCompletionButton: React.FC<{ export const RejectCompletionButton: React.FC<{ onDecline: () => void; + text?: string; size?: "xs" | "sm"; className?: string; declineShortcut?: string; -}> = ({ onDecline, size = "sm", className, declineShortcut }) => { +}> = ({ + onDecline, + text = "Reject", + size = "sm", + className, + declineShortcut, +}) => { return ( + + {currentIndex === null + ? `${listStagedCells.length} pending` + : `${currentIndex + 1} / ${listStagedCells.length}`} + + + + +
+ +
+ + +
+
+ ); +}; diff --git a/frontend/src/core/ai/__tests__/staged-cells.test.ts b/frontend/src/core/ai/__tests__/staged-cells.test.ts index 54a047c41a4..3a2f9328693 100644 --- a/frontend/src/core/ai/__tests__/staged-cells.test.ts +++ b/frontend/src/core/ai/__tests__/staged-cells.test.ts @@ -6,6 +6,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { CellId } from "@/core/cells/ids"; import { updateEditorCodeFromPython } from "../../codemirror/language/utils"; import { + type StagedAICells, stagedAICellsAtom, useStagedCells, visibleForTesting, @@ -35,6 +36,7 @@ vi.mock("../../cells/cells", () => ({ cellHandleAtom: vi.fn(() => ({ read: vi.fn(() => mockCellHandle), })), + getCellEditorView: vi.fn(() => mockCellHandle.current.editorViewOrNull), })); vi.mock("@/components/editor/cell/useDeleteCell", () => ({ @@ -66,32 +68,77 @@ describe("staged-cells", () => { vi.clearAllMocks(); // Reset the atom state - store.set(stagedAICellsAtom, new Set()); + store.set(stagedAICellsAtom, new Map()); }); describe("reducer and actions", () => { it("should initialize with empty map", () => { const state = initialState(); - expect(state).toEqual(new Set()); + expect(state).toEqual(new Map()); }); - it("should add cell IDs", () => { + it("should add cells with update_cell edit", () => { let state = initialState(); state = reducer(state, { type: "addStagedCell", - payload: { cellId: cellId1 }, + payload: { + cellId: cellId1, + edit: { type: "update_cell", previousCode: "old code 1" }, + }, }); state = reducer(state, { type: "addStagedCell", - payload: { cellId: cellId2 }, + payload: { + cellId: cellId2, + edit: { type: "update_cell", previousCode: "old code 2" }, + }, }); expect(state.has(cellId1)).toBe(true); expect(state.has(cellId2)).toBe(true); + expect(state.get(cellId1)).toEqual({ + type: "update_cell", + previousCode: "old code 1", + }); + expect(state.get(cellId2)).toEqual({ + type: "update_cell", + previousCode: "old code 2", + }); + }); + + it("should add cells with add_cell edit", () => { + let state = initialState(); + state = reducer(state, { + type: "addStagedCell", + payload: { cellId: cellId1, edit: { type: "add_cell" } }, + }); + + expect(state.has(cellId1)).toBe(true); + expect(state.get(cellId1)).toEqual({ type: "add_cell" }); + }); + + it("should add cells with delete_cell edit", () => { + let state = initialState(); + state = reducer(state, { + type: "addStagedCell", + payload: { + cellId: cellId1, + edit: { type: "delete_cell", previousCode: "deleted code" }, + }, + }); + + expect(state.has(cellId1)).toBe(true); + expect(state.get(cellId1)).toEqual({ + type: "delete_cell", + previousCode: "deleted code", + }); }); it("should remove cell IDs", () => { - const state = new Set([cellId1, cellId2]); + const state = new Map([ + [cellId1, { type: "add_cell" as const }], + [cellId2, { type: "add_cell" as const }], + ]); const newState = reducer(state, { type: "removeStagedCell", payload: cellId1, @@ -101,23 +148,26 @@ describe("staged-cells", () => { expect(newState.has(cellId2)).toBe(true); }); - it("should clear all cell IDs", () => { - const state = new Set([cellId1, cellId2]); + it("should clear all cells", () => { + const state = new Map([ + [cellId1, { type: "add_cell" as const }], + [cellId2, { type: "add_cell" as const }], + ]); const newState = reducer(state, { type: "clearStagedCells", payload: undefined, }); - expect(newState).toEqual(new Set()); + expect(newState).toEqual(new Map()); }); - it("should not mutate original state", () => { - const state = new Set([cellId1]); + it("should not mutate original state when adding", () => { + const state = new Map([[cellId1, { type: "add_cell" as const }]]); const originalSize = state.size; reducer(state, { type: "addStagedCell", - payload: { cellId: cellId2 }, + payload: { cellId: cellId2, edit: { type: "add_cell" } }, }); expect(state.size).toBe(originalSize); @@ -125,6 +175,23 @@ describe("staged-cells", () => { expect(state.has(cellId2)).toBe(false); }); + it("should not mutate original state when removing", () => { + const state = new Map([ + [cellId1, { type: "add_cell" as const }], + [cellId2, { type: "add_cell" as const }], + ]); + const originalSize = state.size; + + reducer(state, { + type: "removeStagedCell", + payload: cellId1, + }); + + expect(state.size).toBe(originalSize); + expect(state.has(cellId1)).toBe(true); + expect(state.has(cellId2)).toBe(true); + }); + it("should create action functions", () => { const mockDispatch = vi.fn(); const actions = createActions(mockDispatch); @@ -136,7 +203,7 @@ describe("staged-cells", () => { it("should initialize atom with empty map", () => { const state = store.get(stagedAICellsAtom); - expect(state).toEqual(new Set()); + expect(state).toEqual(new Map()); }); }); @@ -181,7 +248,11 @@ describe("staged-cells", () => { it("should delete all staged cells when cells exist", () => { // First set the atom state before rendering the hook - store.set(stagedAICellsAtom, new Set([cellId1, cellId2])); + const initialState: StagedAICells = new Map([ + [cellId1, { type: "add_cell" }], + [cellId2, { type: "add_cell" }], + ]); + store.set(stagedAICellsAtom, initialState); const { result } = renderHook(() => useStagedCells(store)); result.current.deleteAllStagedCells(); @@ -192,76 +263,130 @@ describe("staged-cells", () => { // Verify cells were cleared from the atom const state = store.get(stagedAICellsAtom); - expect(state).toEqual(new Set()); + expect(state).toEqual(new Map()); }); - }); - it("should add staged cell", () => { - const { result } = renderHook(() => useStagedCells(store)); + it("should add staged cell with edit info", () => { + const { result } = renderHook(() => useStagedCells(store)); - result.current.addStagedCell({ cellId: cellId1 }); + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "update_cell", previousCode: "old code" }, + }); - // Check that the cell was added to the atom - const state = store.get(stagedAICellsAtom); - expect(state.has(cellId1)).toBe(true); - }); + // Check that the cell was added to the atom with edit info + const state = store.get(stagedAICellsAtom); + expect(state.has(cellId1)).toBe(true); + expect(state.get(cellId1)).toEqual({ + type: "update_cell", + previousCode: "old code", + }); + }); - it("should remove staged cell", () => { - const { result } = renderHook(() => useStagedCells(store)); + it("should remove staged cell", () => { + const { result } = renderHook(() => useStagedCells(store)); - // First add cells - result.current.addStagedCell({ cellId: cellId1 }); - result.current.addStagedCell({ cellId: cellId2 }); + // First add cells + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "add_cell" }, + }); + result.current.addStagedCell({ + cellId: cellId2, + edit: { type: "add_cell" }, + }); - // Then remove one - result.current.removeStagedCell(cellId1); + // Then remove one + result.current.removeStagedCell(cellId1); - // Check that only the remaining cell is in the map - const state = store.get(stagedAICellsAtom); - expect(state.has(cellId1)).toBe(false); - expect(state.has(cellId2)).toBe(true); - }); + // Check that only the remaining cell is in the map + const state = store.get(stagedAICellsAtom); + expect(state.has(cellId1)).toBe(false); + expect(state.has(cellId2)).toBe(true); + }); - it("should clear all staged cells", () => { - const { result } = renderHook(() => useStagedCells(store)); + it("should clear all staged cells", () => { + const { result } = renderHook(() => useStagedCells(store)); - // First add some cells - result.current.addStagedCell({ cellId: cellId1 }); - result.current.addStagedCell({ cellId: cellId2 }); + // First add some cells + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "add_cell" }, + }); + result.current.addStagedCell({ + cellId: cellId2, + edit: { type: "add_cell" }, + }); - // Then clear all - result.current.clearStagedCells(); + // Then clear all + result.current.clearStagedCells(); - // Check that no cells remain - const state = store.get(stagedAICellsAtom); - expect(state).toEqual(new Set()); - }); + // Check that no cells remain + const state = store.get(stagedAICellsAtom); + expect(state).toEqual(new Map()); + }); - it("should handle multiple operations correctly", () => { - const { result } = renderHook(() => useStagedCells(store)); + it("should handle multiple operations correctly", () => { + const { result } = renderHook(() => useStagedCells(store)); - // Create a staged cell - const mockCellId = "mock-cell-id" as CellId; - vi.mocked(CellId.create).mockReturnValue(mockCellId); + // Create a staged cell + const mockCellId = "mock-cell-id" as CellId; + vi.mocked(CellId.create).mockReturnValue(mockCellId); - const createdCellId = result.current.createStagedCell("test code"); + const createdCellId = result.current.createStagedCell("test code"); - // Verify it was created and added - expect(createdCellId).toBe(mockCellId); - expect(mockCreateNewCell).toHaveBeenCalled(); + // Verify it was created and added + expect(createdCellId).toBe(mockCellId); + expect(mockCreateNewCell).toHaveBeenCalled(); - let state = store.get(stagedAICellsAtom); - expect(state.has(mockCellId)).toBe(true); + let state = store.get(stagedAICellsAtom); + expect(state.has(mockCellId)).toBe(true); + expect(state.get(mockCellId)).toEqual({ type: "add_cell" }); - // Delete the staged cell - result.current.deleteStagedCell(mockCellId); - expect(mockDeleteCellCallback).toHaveBeenCalledWith({ - cellId: mockCellId, + // Delete the staged cell + result.current.deleteStagedCell(mockCellId); + expect(mockDeleteCellCallback).toHaveBeenCalledWith({ + cellId: mockCellId, + }); + + // Verify it was removed from staged cells + state = store.get(stagedAICellsAtom); + expect(state.has(mockCellId)).toBe(false); }); - // Verify it was removed from staged cells - state = store.get(stagedAICellsAtom); - expect(state.has(mockCellId)).toBe(false); + it("should track edit history for updated cells", () => { + const { result } = renderHook(() => useStagedCells(store)); + + // Add a cell with update_cell edit type + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "update_cell", previousCode: "previous code" }, + }); + + const state = store.get(stagedAICellsAtom); + const edit = state.get(cellId1); + expect(edit).toEqual({ + type: "update_cell", + previousCode: "previous code", + }); + }); + + it("should track edit history for deleted cells", () => { + const { result } = renderHook(() => useStagedCells(store)); + + // Add a cell with delete_cell edit type + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "delete_cell", previousCode: "deleted content" }, + }); + + const state = store.get(stagedAICellsAtom); + const edit = state.get(cellId1); + expect(edit).toEqual({ + type: "delete_cell", + previousCode: "deleted content", + }); + }); }); }); diff --git a/frontend/src/core/ai/config.ts b/frontend/src/core/ai/config.ts index d7f1abde916..c78a31a063b 100644 --- a/frontend/src/core/ai/config.ts +++ b/frontend/src/core/ai/config.ts @@ -4,7 +4,11 @@ import type { Role } from "@marimo-team/llm-info"; import { useAtom } from "jotai"; import type { QualifiedModelId } from "@/core/ai/ids/ids"; import { userConfigAtom } from "@/core/config/config"; -import type { AIModelKey, UserConfig } from "@/core/config/config-schema"; +import type { + AIModelKey, + CopilotMode, + UserConfig, +} from "@/core/config/config-schema"; import { useRequestClient } from "@/core/network/requests"; // Extract only the supported roles from the Role type @@ -60,7 +64,7 @@ export const useModelChange = () => { saveConfig(newConfig); }; - const saveModeChange = async (newMode: "ask" | "manual") => { + const saveModeChange = async (newMode: CopilotMode) => { const newConfig: UserConfig = { ...userConfig, ai: { diff --git a/frontend/src/core/ai/staged-cells.ts b/frontend/src/core/ai/staged-cells.ts index a8e77e3cd95..f25ac18a654 100644 --- a/frontend/src/core/ai/staged-cells.ts +++ b/frontend/src/core/ai/staged-cells.ts @@ -13,21 +13,30 @@ import { Logger } from "@/utils/Logger"; import { maybeAddMarimoImport } from "../cells/add-missing-import"; import { type CreateNewCellAction, - cellHandleAtom, + getCellEditorView, useCellActions, } from "../cells/cells"; import type { LanguageAdapterType } from "../codemirror/language/types"; import { updateEditorCodeFromPython } from "../codemirror/language/utils"; import type { JotaiStore } from "../state/jotai"; +import type { EditType } from "./tools/edit-notebook-tool"; /** * Cells that are staged for AI completion - * They function similarly to cells in the notebook, but they can be deleted or accepted by the user. - * We only track one set of staged cells at a time. + * They function similarly to cells in the notebook, but they can be accepted or rejected by the user. + * We track edited, new and deleted cells. + * And we only track one set of staged cells at a time. */ -const initialState = (): Set => { - return new Set(); +export type Edit = + | { type: Extract; previousCode: string } + | { type: Extract } + | { type: Extract; previousCode: string }; + +export type StagedAICells = Map; + +const initialState = (): StagedAICells => { + return new Map(); }; const { @@ -36,18 +45,25 @@ const { createActions, reducer, } = createReducerAndAtoms(initialState, { - addStagedCell: (state, action: { cellId: CellId }) => { - const { cellId } = action; - return new Set([...state, cellId]); + addStagedCell: (state, action: { cellId: CellId; edit: Edit }) => { + const { cellId, edit } = action; + return new Map([...state, [cellId, edit]]); }, removeStagedCell: (state, cellId: CellId) => { - return new Set([...state].filter((id) => id !== cellId)); + const newState = new Map(state); + newState.delete(cellId); + return newState; }, clearStagedCells: () => { return initialState(); }, }); +export { + createActions as createStagedAICellsActions, + reducer as stagedAICellsReducer, +}; + interface UpdateStagedCellAction { cellId: CellId; code: string; @@ -67,7 +83,7 @@ export function useStagedCells(store: JotaiStore) { const createStagedCell = (code: string): CellId => { const newCellId = CellId.create(); - addStagedCell({ cellId: newCellId }); + addStagedCell({ cellId: newCellId, edit: { type: "add_cell" } }); createNewCell({ cellId: "__end__", code, @@ -86,8 +102,7 @@ export function useStagedCells(store: JotaiStore) { return; } - const cellHandle = store.get(cellHandleAtom(cellId)); - const editorView = cellHandle?.current?.editorViewOrNull; + const editorView = getCellEditorView(cellId); if (!editorView) { Logger.error("Editor for this cell not found", { cellId }); return; @@ -105,7 +120,7 @@ export function useStagedCells(store: JotaiStore) { // Delete all staged cells and the corresponding cells in the notebook. const deleteAllStagedCells = () => { const stagedAICells = store.get(stagedAICellsAtom); - for (const cellId of stagedAICells) { + for (const cellId of stagedAICells.keys()) { deleteCellCallback({ cellId }); } clearStagedCells(); @@ -175,14 +190,14 @@ class CellCreationStream { private onCreateCell: (code: string) => CellId; private onUpdateCell: (opts: UpdateStagedCellAction) => void; - private addStagedCell: (payload: { cellId: CellId }) => void; + private addStagedCell: (payload: { cellId: CellId; edit: Edit }) => void; private createNewCell: (opts: CreateNewCellAction) => void; private hasMarimoImport = false; constructor( onCreateCell: (code: string) => CellId, onUpdateCell: (opts: UpdateStagedCellAction) => void, - addStagedCell: (payload: { cellId: CellId }) => void, + addStagedCell: (payload: { cellId: CellId; edit: Edit }) => void, createNewCell: (opts: CreateNewCellAction) => void, ) { this.onCreateCell = onCreateCell; @@ -229,7 +244,7 @@ class CellCreationStream { before: true, }); if (cellId) { - this.addStagedCell({ cellId }); + this.addStagedCell({ cellId, edit: { type: "add_cell" } }); } this.hasMarimoImport = true; } diff --git a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts new file mode 100644 index 00000000000..333df23e208 --- /dev/null +++ b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts @@ -0,0 +1,468 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { EditorState } from "@codemirror/state"; +import { EditorView } from "@codemirror/view"; +import { getDefaultStore } from "jotai"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { MockNotebook } from "@/__mocks__/notebook"; +import { notebookAtom } from "@/core/cells/cells"; +import type { CellId } from "@/core/cells/ids"; +import { OverridingHotkeyProvider } from "@/core/hotkeys/hotkeys"; +import type { CellColumnId } from "@/utils/id-tree"; +import { MultiColumn } from "@/utils/id-tree"; +import { cellConfigExtension } from "../../../codemirror/config/extension"; +import { adaptiveLanguageConfiguration } from "../../../codemirror/language/extension"; +import { stagedAICellsAtom } from "../../staged-cells"; +import { ToolExecutionError } from "../base"; +import { EditNotebookTool } from "../edit-notebook-tool"; + +// Mock scrollAndHighlightCell +vi.mock("@/components/editor/links/cell-link", () => ({ + scrollAndHighlightCell: vi.fn(), +})); + +// Mock updateEditorCodeFromPython +vi.mock("@/core/codemirror/language/utils", () => ({ + updateEditorCodeFromPython: vi.fn(), +})); + +import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils"; + +function createMockEditorView(code: string): EditorView { + return new EditorView({ + state: EditorState.create({ + doc: code, + extensions: [ + adaptiveLanguageConfiguration({ + cellId: "cell1" as CellId, + completionConfig: { + copilot: false, + activate_on_typing: true, + codeium_api_key: null, + }, + hotkeys: new OverridingHotkeyProvider({}), + placeholderType: "marimo-import", + lspConfig: {}, + }), + cellConfigExtension({ + completionConfig: { + copilot: false, + activate_on_typing: true, + codeium_api_key: null, + }, + hotkeys: new OverridingHotkeyProvider({}), + placeholderType: "marimo-import", + lspConfig: {}, + diagnosticsConfig: {}, + }), + ], + }), + }); +} + +describe("EditNotebookTool", () => { + let store: ReturnType; + let tool: EditNotebookTool; + let cellId1: CellId; + let cellId2: CellId; + let cellId3: CellId; + + beforeEach(() => { + store = getDefaultStore(); + tool = new EditNotebookTool(store); + + cellId1 = "cell-1" as CellId; + cellId2 = "cell-2" as CellId; + cellId3 = "cell-3" as CellId; + + // Reset mocks + vi.clearAllMocks(); + + // Reset atom states + store.set(stagedAICellsAtom, new Map()); + }); + + describe("tool metadata", () => { + it("should have correct metadata", () => { + expect(tool.name).toBe("edit_notebook_tool"); + expect(tool.description).toBeDefined(); + expect(tool.description.baseDescription).toContain("editing operations"); + expect(tool.schema).toBeDefined(); + expect(tool.outputSchema).toBeDefined(); + expect(tool.mode).toEqual(["agent"]); + }); + }); + + describe("update_cell operation", () => { + it("should update cell with new code", async () => { + const oldCode = "x = 1"; + const newCode = "x = 2"; + + // Create notebook state with mock editor view + const editorView = createMockEditorView(oldCode); + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: oldCode }, + }, + }); + notebook.cellHandles[cellId1] = { current: { editorView } } as never; + store.set(notebookAtom, notebook); + + const result = await tool.handler({ + edit: { + type: "update_cell", + cellId: cellId1, + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + expect(vi.mocked(updateEditorCodeFromPython)).toHaveBeenCalledWith( + editorView, + newCode, + ); + + // Check that cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.has(cellId1)).toBe(true); + expect(stagedCells.get(cellId1)).toEqual({ + type: "update_cell", + previousCode: oldCode, + }); + }); + + it("should throw error when cell ID doesn't exist", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: "nonexistent" as CellId, + code: "x = 2", + }, + }), + ).rejects.toThrow(ToolExecutionError); + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: "nonexistent" as CellId, + code: "x = 2", + }, + }), + ).rejects.toThrow("Cell not found"); + }); + + it("should throw error when cell editor not found", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + // Don't set editorView + notebook.cellHandles[cellId1] = { current: null } as never; + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: cellId1, + code: "x = 2", + }, + }), + ).rejects.toThrow("Cell editor not found"); + }); + }); + + describe("add_cell operation", () => { + it("should add cell at the end of the notebook", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + const newCode = "y = 2"; + const result = await tool.handler({ + edit: { + type: "add_cell", + position: "__end__", + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + + // Check that a new cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.size).toBe(1); + const [cellId, edit] = [...stagedCells.entries()][0]; + expect(edit).toEqual({ type: "add_cell" }); + expect(cellId).toBeDefined(); + }); + + it("should add cell before a specific cell", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + [cellId2]: { code: "x = 2" }, + }, + }); + store.set(notebookAtom, notebook); + + const newCode = "y = 2"; + const result = await tool.handler({ + edit: { + type: "add_cell", + position: { cellId: cellId2, before: true }, + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + + // Check that a new cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.size).toBe(1); + const [_cellId, edit] = [...stagedCells.entries()][0]; + expect(edit).toEqual({ type: "add_cell" }); + }); + + it("should add cell after a specific cell", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + [cellId2]: { code: "x = 2" }, + }, + }); + store.set(notebookAtom, notebook); + + const newCode = "y = 2"; + const result = await tool.handler({ + edit: { + type: "add_cell", + position: { cellId: cellId2, before: false }, + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + + // Check that a new cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.size).toBe(1); + }); + + it("should add cell at the end of a specific column", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + [cellId2]: { code: "x = 2" }, + }, + }); + // Create multi-column layout + notebook.cellIds = MultiColumn.from([[cellId1], [cellId2]]); + const columnId = notebook.cellIds.getColumns()[1].id; + store.set(notebookAtom, notebook); + + const newCode = "y = 2"; + const result = await tool.handler({ + edit: { + type: "add_cell", + position: { type: "__end__", columnId }, + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + + // Check that a new cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.size).toBe(1); + }); + + it("should throw error when cell ID doesn't exist for position", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "add_cell", + position: { cellId: "nonexistent" as CellId, before: true }, + code: "y = 2", + }, + }), + ).rejects.toThrow("Cell not found"); + }); + + it("should throw error when column ID doesn't exist", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "add_cell", + position: { + type: "__end__", + columnId: "nonexistent" as CellColumnId, + }, + code: "y = 2", + }, + }), + ).rejects.toThrow("Column not found"); + }); + }); + + describe("delete_cell operation", () => { + it("should delete a cell", async () => { + const cellCode = "x = 1"; + const editorView = createMockEditorView(cellCode); + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: cellCode }, + [cellId2]: { code: "x = 2" }, + }, + }); + notebook.cellHandles[cellId1] = { current: { editorView } } as never; + store.set(notebookAtom, notebook); + + const result = await tool.handler({ + edit: { + type: "delete_cell", + cellId: cellId1, + }, + }); + + expect(result.status).toBe("success"); + + // Check that cell was staged for deletion + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.has(cellId1)).toBe(true); + expect(stagedCells.get(cellId1)).toEqual({ + type: "delete_cell", + previousCode: cellCode, + }); + }); + + it("should throw error when cell ID doesn't exist", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "delete_cell", + cellId: "nonexistent" as CellId, + }, + }), + ).rejects.toThrow("Cell not found"); + }); + + it("should throw error when cell editor not found", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + // Don't set editorView + notebook.cellHandles[cellId1] = { current: null } as never; + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "delete_cell", + cellId: cellId1, + }, + }), + ).rejects.toThrow("Cell editor not found"); + }); + }); + + describe("validation", () => { + it("should validate cell exists in multi-column notebook", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + [cellId2]: { code: "x = 2" }, + [cellId3]: { code: "x = 3" }, + }, + }); + // Create multi-column layout + notebook.cellIds = MultiColumn.from([[cellId1, cellId2], [cellId3]]); + store.set(notebookAtom, notebook); + + // Should not throw for cells in different columns + const editorView = createMockEditorView("x = 1"); + notebook.cellHandles[cellId1] = { current: { editorView } } as never; + notebook.cellHandles[cellId3] = { current: { editorView } } as never; + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: cellId1, + code: "y = 1", + }, + }), + ).resolves.toBeDefined(); + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: cellId3, + code: "y = 3", + }, + }), + ).resolves.toBeDefined(); + }); + }); + + describe("return value", () => { + it("should return success status with next steps", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + const result = await tool.handler({ + edit: { + type: "add_cell", + position: "__end__", + code: "y = 2", + }, + }); + + expect(result.status).toBe("success"); + expect(result.next_steps).toBeDefined(); + expect(Array.isArray(result.next_steps)).toBe(true); + expect(result.next_steps?.length).toBeGreaterThan(0); + }); + }); +}); diff --git a/frontend/src/core/ai/tools/__tests__/utils.test.ts b/frontend/src/core/ai/tools/__tests__/utils.test.ts new file mode 100644 index 00000000000..b194a022e40 --- /dev/null +++ b/frontend/src/core/ai/tools/__tests__/utils.test.ts @@ -0,0 +1,87 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { describe, expect, it } from "vitest"; +import type { ToolDescription } from "../base"; +import { formatToolDescription } from "../utils"; + +describe("formatToolDescription", () => { + it("formats a basic description only", () => { + const description: ToolDescription = { + baseDescription: "This is a simple tool", + }; + + expect(formatToolDescription(description)).toMatchInlineSnapshot( + `"This is a simple tool"`, + ); + }); + + it("formats description with whenToUse", () => { + const description: ToolDescription = { + baseDescription: "Edit notebook cells", + whenToUse: [ + "When user requests code changes", + "When refactoring is needed", + ], + }; + + expect(formatToolDescription(description)).toMatchInlineSnapshot(` + "Edit notebook cells + + ## When to use: + - When user requests code changes + - When refactoring is needed" + `); + }); + + it("formats description with all fields", () => { + const description: ToolDescription = { + baseDescription: "A comprehensive tool", + whenToUse: ["Use case 1", "Use case 2"], + avoidIf: ["Avoid case 1", "Avoid case 2"], + prerequisites: ["Prerequisite 1", "Prerequisite 2"], + sideEffects: ["Side effect 1", "Side effect 2"], + additionalInfo: "Some extra context and notes.", + }; + + expect(formatToolDescription(description)).toMatchInlineSnapshot(` + "A comprehensive tool + + ## When to use: + - Use case 1 + - Use case 2 + + ## Avoid if: + - Avoid case 1 + - Avoid case 2 + + ## Prerequisites: + - Prerequisite 1 + - Prerequisite 2 + + ## Side effects: + - Side effect 1 + - Side effect 2 + + ## Additional info: + - Some extra context and notes." + `); + }); + + it("formats description with only some fields", () => { + const description: ToolDescription = { + baseDescription: "A selective tool", + avoidIf: ["Don't use in production"], + sideEffects: ["Modifies state"], + }; + + expect(formatToolDescription(description)).toMatchInlineSnapshot(` + "A selective tool + + ## Avoid if: + - Don't use in production + + ## Side effects: + - Modifies state" + `); + }); +}); diff --git a/frontend/src/core/ai/tools/base.ts b/frontend/src/core/ai/tools/base.ts index e81b1b98026..845943e2efe 100644 --- a/frontend/src/core/ai/tools/base.ts +++ b/frontend/src/core/ai/tools/base.ts @@ -83,15 +83,18 @@ export const toolOutputBaseSchema = z.object({ meta: z.record(z.string(), z.unknown()).optional(), }); -/** - * Contract for a frontend tool. - * - * Implementations can be plain objects or classes. The registry consumes this - * interface without caring about the underlying implementation. - */ +export interface ToolDescription { + baseDescription: string; + whenToUse?: string[]; + avoidIf?: string[]; + prerequisites?: string[]; + sideEffects?: string[]; + additionalInfo?: string; +} + export interface AiTool { name: string; - description: string; + description: ToolDescription; schema: z.ZodType; outputSchema: z.ZodType; mode: CopilotMode[]; diff --git a/frontend/src/core/ai/tools/edit-notebook-tool.ts b/frontend/src/core/ai/tools/edit-notebook-tool.ts new file mode 100644 index 00000000000..f9739e15d5f --- /dev/null +++ b/frontend/src/core/ai/tools/edit-notebook-tool.ts @@ -0,0 +1,246 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import type { EditorView } from "@codemirror/view"; +import { z } from "zod"; +import { scrollAndHighlightCell } from "@/components/editor/links/cell-link"; +import { + createNotebookActions, + type CellPosition as NotebookCellPosition, + type NotebookState, + notebookAtom, + notebookReducer, +} from "@/core/cells/cells"; +import { CellId } from "@/core/cells/ids"; +import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils"; +import type { JotaiStore } from "@/core/state/jotai"; +import type { CellColumnId } from "@/utils/id-tree"; +import { + createStagedAICellsActions, + stagedAICellsAtom, + stagedAICellsReducer, +} from "../staged-cells"; +import { + type AiTool, + type ToolDescription, + ToolExecutionError, + type ToolOutputBase, + toolOutputBaseSchema, +} from "./base"; +import type { CopilotMode } from "./registry"; + +const description: ToolDescription = { + baseDescription: + "Perform editing operations on the current notebook. Call this tool multiple times to perform multiple edits.", + prerequisites: [ + "Find out the cellIds and columnIds first (call lightweight cell map tool)", + ], + additionalInfo: ` + Args: + edit (object): The editing operation to perform. Must be one of: + - update_cell: Update the code of an existing cell, pass CellId and the new code. + - add_cell: Add a new cell to the notebook. The position of the new cell is specified by the position argument. + Pass "__end__" to add the new cell at the end of the notebook. + Pass { cellId: cellId, before: true } to add the new cell before the specified cell. And before: false if after the specified cell. + Pass { type: "__end__", columnId: columnId } to add the new cell at the end of the specified column. + - delete_cell: Delete an existing cell, pass CellId. For deleting cells, the user needs to accept the deletion to actually delete the cell, so you may still see the cell in the notebook on subsequent edits which is fine. + + For adding code, use the following guidelines: + - Markdown cells: use mo.md(f"""{content}""") function to insert content. + - SQL cells: use mo.sql(f"""{content}""") function to insert content. If a database engine is specified, use mo.sql(f"""{content}""", engine=engine) instead. + + Returns: + - A result object containing standard tool metadata.`, +}; + +type CellPosition = + | { cellId: CellId; before: boolean } + | { type: "__end__"; columnId: CellColumnId } + | "__end__"; + +const editNotebookSchema = z.object({ + edit: z.discriminatedUnion("type", [ + z.object({ + type: z.literal("update_cell"), + cellId: z.string() as unknown as z.ZodType, + code: z.string(), + }), + z.object({ + type: z.literal("add_cell"), + position: z.union([ + z.object({ + cellId: z.string() as unknown as z.ZodType, + before: z.boolean(), + }), + z.object({ + type: z.literal("__end__"), + columnId: z.string() as unknown as z.ZodType, + }), + z.literal("__end__"), + ]) satisfies z.ZodType, + code: z.string(), + }), + z.object({ + type: z.literal("delete_cell"), + cellId: z.string() as unknown as z.ZodType, + }), + ]), +}); + +type EditNotebookInput = z.infer; +type EditOperation = EditNotebookInput["edit"]; +export type EditType = EditOperation["type"]; + +export class EditNotebookTool + implements AiTool +{ + private readonly store: JotaiStore; + private readonly notebookActions: ReturnType; + private readonly stagedAICellsActions: ReturnType< + typeof createStagedAICellsActions + >; + readonly name = "edit_notebook_tool"; + readonly description = description; + readonly schema = editNotebookSchema; + readonly outputSchema = toolOutputBaseSchema; + readonly mode: CopilotMode[] = ["agent"]; + + constructor(store: JotaiStore) { + this.store = store; + this.notebookActions = createNotebookActions((action) => { + this.store.set(notebookAtom, (state) => notebookReducer(state, action)); + }); + this.stagedAICellsActions = createStagedAICellsActions((action) => { + this.store.set(stagedAICellsAtom, (state) => + stagedAICellsReducer(state, action), + ); + }); + } + + handler = async ({ edit }: EditNotebookInput): Promise => { + switch (edit.type) { + case "update_cell": { + const { cellId, code } = edit; + + const notebook = this.store.get(notebookAtom); + this.validateCellIdExists(cellId, notebook); + const editorView = this.getCellEditorView(cellId, notebook); + + scrollAndHighlightCell(cellId); + + const currentCellCode = editorView.state.doc.toString(); + this.stagedAICellsActions.addStagedCell({ + cellId, + edit: { type: "update_cell", previousCode: currentCellCode }, + }); + + updateEditorCodeFromPython(editorView, code); + + break; + } + case "add_cell": { + const { position, code } = edit; + + // By default, add the new cell to the end of the notebook + let notebookPosition: NotebookCellPosition = "__end__"; + let before = false; + const newCellId = CellId.create(); + + if (typeof position === "object") { + const notebook = this.store.get(notebookAtom); + if ("cellId" in position) { + this.validateCellIdExists(position.cellId, notebook); + notebookPosition = position.cellId; + before = position.before; + } else if ("columnId" in position) { + this.validateColumnIdExists(position.columnId, notebook); + notebookPosition = { type: "__end__", columnId: position.columnId }; + } + } + + this.notebookActions.createNewCell({ + cellId: notebookPosition, + before, + code, + newCellId, + }); + + // Add to staged AICells + this.stagedAICellsActions.addStagedCell({ + cellId: newCellId, + edit: { type: "add_cell" }, + }); + + // Scroll into view + scrollAndHighlightCell(newCellId); + + break; + } + case "delete_cell": { + const { cellId } = edit; + + const notebook = this.store.get(notebookAtom); + this.validateCellIdExists(cellId, notebook); + + const editorView = this.getCellEditorView(cellId, notebook); + const currentCellCode = editorView.state.doc.toString(); + + // Add to staged AICells - don't actually delete the cell yet + this.stagedAICellsActions.addStagedCell({ + cellId, + edit: { type: "delete_cell", previousCode: currentCellCode }, + }); + + scrollAndHighlightCell(cellId); + break; + } + } + return { + status: "success", + next_steps: ["If you need to perform more edits, call this tool again."], + }; + }; + + private validateCellIdExists(cellId: CellId, notebook: NotebookState) { + const cellIds = notebook.cellIds; + if (!cellIds.getColumns().some((column) => column.idSet.has(cellId))) { + throw new ToolExecutionError( + "Cell not found", + "CELL_NOT_FOUND", + false, + "Check which cells exist in the notebook", + ); + } + } + + private validateColumnIdExists( + columnId: CellColumnId, + notebook: NotebookState, + ) { + const cellIds = notebook.cellIds; + if (!cellIds.getColumns().some((column) => column.id === columnId)) { + throw new ToolExecutionError( + "Column not found", + "COLUMN_NOT_FOUND", + false, + "Check which columns exist in the notebook", + ); + } + } + + private getCellEditorView( + cellId: CellId, + notebook: NotebookState, + ): EditorView { + const cellHandles = notebook.cellHandles; + const cellHandle = cellHandles[cellId].current; + if (!cellHandle?.editorView) { + throw new ToolExecutionError( + "Cell editor not found", + "CELL_EDITOR_NOT_FOUND", + false, + "Internal error, ask the user to report this error", + ); + } + return cellHandle.editorView; + } +} diff --git a/frontend/src/core/ai/tools/registry.ts b/frontend/src/core/ai/tools/registry.ts index 8897b6051fb..72b4b756d79 100644 --- a/frontend/src/core/ai/tools/registry.ts +++ b/frontend/src/core/ai/tools/registry.ts @@ -3,8 +3,10 @@ import type { components } from "@marimo-team/marimo-api"; import { Memoize } from "typescript-memoize"; import { type ZodObject, z } from "zod"; +import { store } from "@/core/state/jotai"; import { type AiTool, ToolExecutionError } from "./base"; -import { TestFrontendTool } from "./sample-tool"; +import { EditNotebookTool } from "./edit-notebook-tool"; +import { formatToolDescription } from "./utils"; export type AnyZodObject = ZodObject; @@ -114,7 +116,7 @@ export class FrontendToolRegistry { getToolSchemas(): FrontendToolDefinition[] { return [...this.tools.values()].map((tool) => ({ name: tool.name, - description: tool.description, + description: formatToolDescription(tool.description), parameters: z.toJSONSchema(tool.schema), source: "frontend", mode: tool.mode, @@ -123,6 +125,6 @@ export class FrontendToolRegistry { } export const FRONTEND_TOOL_REGISTRY = new FrontendToolRegistry([ - ...(import.meta.env.DEV ? [new TestFrontendTool()] : []), + new EditNotebookTool(store), // ADD MORE TOOLS HERE ]); diff --git a/frontend/src/core/ai/tools/sample-tool.ts b/frontend/src/core/ai/tools/sample-tool.ts index 8e1f93a52e4..b81fef49f8e 100644 --- a/frontend/src/core/ai/tools/sample-tool.ts +++ b/frontend/src/core/ai/tools/sample-tool.ts @@ -3,21 +3,23 @@ import { z } from "zod"; import { type AiTool, + type ToolDescription, ToolExecutionError, type ToolOutputBase, toolOutputBaseSchema, } from "./base"; import type { CopilotMode } from "./registry"; -const description = ` -Test frontend tool that returns a greeting message. +const description: ToolDescription = { + baseDescription: "Test frontend tool that returns a greeting message.", + additionalInfo: ` + Args: + - name (string): The name to include in the greeting. -Args: -- name (string): The name to include in the greeting. - -Returns: -- Output with data containing the greeting message. -`; + Returns: + - Output with data containing the greeting message. + `, +}; interface Input { name: string; diff --git a/frontend/src/core/ai/tools/utils.ts b/frontend/src/core/ai/tools/utils.ts new file mode 100644 index 00000000000..f8aca9fbeb7 --- /dev/null +++ b/frontend/src/core/ai/tools/utils.ts @@ -0,0 +1,23 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import type { ToolDescription } from "./base"; + +export function formatToolDescription(description: ToolDescription): string { + let result = description.baseDescription; + if (description.whenToUse) { + result += `\n\n## When to use:\n- ${description.whenToUse.join("\n- ")}`; + } + if (description.avoidIf) { + result += `\n\n## Avoid if:\n- ${description.avoidIf.join("\n- ")}`; + } + if (description.prerequisites) { + result += `\n\n## Prerequisites:\n- ${description.prerequisites.join("\n- ")}`; + } + if (description.sideEffects) { + result += `\n\n## Side effects:\n- ${description.sideEffects.join("\n- ")}`; + } + if (description.additionalInfo) { + result += `\n\n## Additional info:\n- ${description.additionalInfo}`; + } + return result; +} diff --git a/frontend/src/core/cells/cells.ts b/frontend/src/core/cells/cells.ts index 5fd8e220ba1..c7519fdc608 100644 --- a/frontend/src/core/cells/cells.ts +++ b/frontend/src/core/cells/cells.ts @@ -145,13 +145,18 @@ export function initialNotebookState(): NotebookState { }); } +/** The target cell ID to create a new cell relative to. Can be: + * - A CellId string for an existing cell + * - "__end__" to append at the end of the first column + * - {type: "__end__", columnId} to append at the end of a specific column + */ +export type CellPosition = + | CellId + | "__end__" + | { type: "__end__"; columnId: CellColumnId }; + export interface CreateNewCellAction { - /** The target cell ID to create a new cell relative to. Can be: - * - A CellId string for an existing cell - * - "__end__" to append at the end of the first column - * - {type: "__end__", columnId} to append at the end of a specific column - */ - cellId: CellId | "__end__" | { type: "__end__"; columnId: CellColumnId }; + cellId: CellPosition; /** Whether to insert before (true) or after (false) the target cell */ before: boolean; /** Initial code content for the new cell */ diff --git a/frontend/src/core/config/config-schema.ts b/frontend/src/core/config/config-schema.ts index c513b4ef5c6..4e486e57a72 100644 --- a/frontend/src/core/config/config-schema.ts +++ b/frontend/src/core/config/config-schema.ts @@ -44,6 +44,9 @@ export const DEFAULT_AI_MODEL = "openai/gpt-4o"; */ const AUTO_DOWNLOAD_FORMATS = ["html", "markdown", "ipynb"] as const; +const COPILOT_MODES = ["manual", "ask", "agent"] as const; +export type CopilotMode = (typeof COPILOT_MODES)[number]; + const AiConfigSchema = z .object({ api_key: z.string().optional(), @@ -151,7 +154,7 @@ export const UserConfigSchema = z ai: z .looseObject({ rules: z.string().prefault(""), - mode: z.enum(["manual", "ask"]).prefault("manual"), + mode: z.enum(COPILOT_MODES).prefault("manual"), inline_tooltip: z.boolean().prefault(false), open_ai: AiConfigSchema.optional(), anthropic: AiConfigSchema.optional(), diff --git a/frontend/src/css/app/Cell.css b/frontend/src/css/app/Cell.css index e75e965f154..8a090ef3ef4 100644 --- a/frontend/src/css/app/Cell.css +++ b/frontend/src/css/app/Cell.css @@ -34,6 +34,37 @@ z-index: 1; } + /* Styling for cells marked for deletion by AI */ + &:has(.mo-ai-deleted-cell) { + opacity: 0.7; + position: relative; + + .output-area, + .cm-gutters, + .cm { + background-color: var(--red-2); + } + + /* Add a red glow to indicate deletion */ + &::before { + content: ""; + position: absolute; + inset: 0; + border-radius: 10px; + pointer-events: none; + opacity: 0.6; + box-shadow: 0px 0px 10px 0px var(--red-9); + z-index: 1; + } + + /* Add strikethrough effect to the code */ + .cm-content { + text-decoration: line-through; + text-decoration-color: var(--red-9); + text-decoration-thickness: 1px; + } + } + &:focus-within { z-index: 20; } diff --git a/frontend/src/utils/__tests__/arrays.test.ts b/frontend/src/utils/__tests__/arrays.test.ts index 47da4be94ae..d5108086492 100644 --- a/frontend/src/utils/__tests__/arrays.test.ts +++ b/frontend/src/utils/__tests__/arrays.test.ts @@ -8,6 +8,7 @@ import { arrayMove, arrayShallowEquals, arrayToggle, + getNextIndex, } from "../arrays"; describe("arrays", () => { @@ -128,3 +129,44 @@ describe("arrays", () => { }); }); }); + +describe("getNextIndex", () => { + it("should return 0 if listLength is 0", () => { + expect(getNextIndex(null, 0, "up")).toBe(0); + expect(getNextIndex(null, 0, "down")).toBe(0); + expect(getNextIndex(0, 0, "up")).toBe(0); + expect(getNextIndex(0, 0, "down")).toBe(0); + }); + + it("should return 0 if currentIndex is null and direction is up", () => { + expect(getNextIndex(null, 10, "up")).toBe(0); + }); + + it("should return last index if currentIndex is null and direction is down", () => { + expect(getNextIndex(null, 10, "down")).toBe(9); + }); + + it("should return the next index in the list", () => { + expect(getNextIndex(0, 10, "up")).toBe(1); + }); + + it("should wrap around to the start of the list", () => { + expect(getNextIndex(9, 10, "up")).toBe(0); + }); + + it("should return next index in middle of list", () => { + expect(getNextIndex(5, 10, "up")).toBe(6); + }); + + it("should return previous index in middle of list", () => { + expect(getNextIndex(5, 10, "down")).toBe(4); + }); + + it("should wrap around to the end of the list", () => { + expect(getNextIndex(0, 10, "down")).toBe(9); + }); + + it("should return the previous index in the list", () => { + expect(getNextIndex(1, 10, "down")).toBe(0); + }); +}); diff --git a/frontend/src/utils/arrays.ts b/frontend/src/utils/arrays.ts index b3107806b02..f1fa2015380 100644 --- a/frontend/src/utils/arrays.ts +++ b/frontend/src/utils/arrays.ts @@ -86,3 +86,28 @@ export function uniqueBy(arr: T[], key: (item: T) => string): T[] { } return result; } + +/** + * Get the next index in the list, wrapping around to the start or end if necessary. + * @param currentIndex - The current index, or null if there is no current index. + * @param listLength - The length of the list. + * @param direction - The direction to move in. + * @returns The next index. + */ +export function getNextIndex( + currentIndex: number | null, + listLength: number, + direction: "up" | "down", +): number { + if (listLength === 0) { + return 0; + } + + if (currentIndex === null) { + return direction === "up" ? 0 : listLength - 1; + } + + return direction === "up" + ? (currentIndex + 1) % listLength + : (currentIndex - 1 + listLength) % listLength; +} diff --git a/marimo/_config/config.py b/marimo/_config/config.py index 72923a325d2..5802e60eb21 100644 --- a/marimo/_config/config.py +++ b/marimo/_config/config.py @@ -231,7 +231,7 @@ class PackageManagementConfig(TypedDict): manager: Literal["pip", "rye", "uv", "poetry", "pixi"] -CopilotMode = Literal["ask", "manual"] +CopilotMode = Literal["ask", "manual", "agent"] @mddoc diff --git a/marimo/_server/ai/prompts.py b/marimo/_server/ai/prompts.py index cc8da16e471..67d6df137e8 100644 --- a/marimo/_server/ai/prompts.py +++ b/marimo/_server/ai/prompts.py @@ -244,6 +244,14 @@ def _get_mode_intro_message(mode: CopilotMode) -> str: "- All tool use is strictly read-only. You may not perform write, edit, or execution actions.\n" "- You must always explain to the user why you are using a tool before invoking it.\n" ) + elif mode == "agent": + return ( + f"{base_intro}" + "## Capabilities\n" + "- You can use a set of read and write tools to gather additional context from the notebook or environment (e.g., searching code, summarizing data, or reading documentation) and to modify the notebook (e.g., adding cells, editing cells, deleting cells).\n" + "## Limitations\n" + "- You must always explain to the user why you are using a tool before invoking it.\n" + ) def _get_session_info(session_id: SessionId) -> str: @@ -381,13 +389,27 @@ def get_chat_system_prompt( chart """ - for language in language_rules: - if len(language_rules[language]) == 0: - continue + if mode == "agent": + # In agent-mode, we can add how to insert cells into the notebook. + for lang in LANGUAGES: + # check if multiple_cells rules are present for this language + rule_to_add = language_rules_multiple_cells.get( + lang, language_rules.get(lang, []) + ) + if rule_to_add: + system_prompt += ( + f"\n\n## Rules for {lang}:\n{_rules(rule_to_add)}" + ) + + system_prompt += "\n\n## Rules for inserting cells:\n" + system_prompt += 'For markdown cells, use `mo.md(f"""{content}""")`\n' + system_prompt += 'For sql cells, use `mo.sql(f"""{content}""")`. If a database engine is specified, use `mo.sql(f"""{content}""", engine=engine)` instead.\n' + else: + for language in language_rules: + if len(language_rules[language]) == 0: + continue - system_prompt += ( - f"\n\n## Rules for {language}:\n{_rules(language_rules[language])}" - ) + system_prompt += f"\n\n## Rules for {language}:\n{_rules(language_rules[language])}" if custom_rules and custom_rules.strip(): system_prompt += f"\n\n## Additional rules:\n{custom_rules}" diff --git a/marimo/_server/ai/tools/tool_manager.py b/marimo/_server/ai/tools/tool_manager.py index cc616057e96..11f6d75734d 100644 --- a/marimo/_server/ai/tools/tool_manager.py +++ b/marimo/_server/ai/tools/tool_manager.py @@ -53,7 +53,8 @@ def _register_backend_tool(self, tool: ToolBase[Any, Any]) -> None: """Register a backend tool with its handler function and optional validator.""" name = tool.name tool_definition, validation_function = tool.as_backend_tool( - mode=["ask"] + # TODO: Tools should define their own supported modes + mode=["ask", "agent"] ) self._tools[name] = tool_definition self._backend_handlers[name] = tool.__call__ @@ -175,8 +176,9 @@ def _convert_mcp_tool(self, mcp_tool: MCPRawTool) -> ToolDefinition: description=mcp_tool.description or "No description available", parameters=mcp_tool.inputSchema, source="mcp", - mode=["ask"], # MCP tools available in ask mode for now - # TODO(bjoaquinc): change default mode to "agent" when we add agent mode + # MCP tools available in ask mode and agent mode + # TODO: Determine which tools to support in agent mode + mode=["ask", "agent"], ) async def invoke_tool( diff --git a/packages/openapi/api.yaml b/packages/openapi/api.yaml index e6ccafe9749..9cd5bded4d9 100644 --- a/packages/openapi/api.yaml +++ b/packages/openapi/api.yaml @@ -114,6 +114,7 @@ components: type: integer mode: enum: + - agent - ask - manual models: @@ -3403,6 +3404,7 @@ components: mode: items: enum: + - agent - ask - manual type: array diff --git a/packages/openapi/src/api.ts b/packages/openapi/src/api.ts index aaef4207ea8..5c5fce39b87 100644 --- a/packages/openapi/src/api.ts +++ b/packages/openapi/src/api.ts @@ -2847,7 +2847,7 @@ export interface components { inline_tooltip?: boolean; max_tokens?: number; /** @enum {unknown} */ - mode?: "ask" | "manual"; + mode?: "agent" | "ask" | "manual"; models?: components["schemas"]["AiModelConfig"]; ollama?: components["schemas"]["OpenAiConfig"]; open_ai?: components["schemas"]["OpenAiConfig"]; @@ -4614,7 +4614,7 @@ export interface components { */ ToolDefinition: { description: string; - mode: ("ask" | "manual")[]; + mode: ("agent" | "ask" | "manual")[]; name: string; parameters: Record; /** @enum {unknown} */ diff --git a/tests/_server/ai/snapshots/chat_system_prompts.txt b/tests/_server/ai/snapshots/chat_system_prompts.txt index 69f3377a28e..99c3d312a9c 100644 --- a/tests/_server/ai/snapshots/chat_system_prompts.txt +++ b/tests/_server/ai/snapshots/chat_system_prompts.txt @@ -889,6 +889,153 @@ import pandas as pd import numpy as np +==================== with agent mode ==================== + + +You are Marimo Copilot, an AI assistant integrated into the marimo notebook code editor. +Your primary function is to help users create, analyze, and improve data science notebooks using marimo's reactive programming model. +## Capabilities +- You can use a set of read and write tools to gather additional context from the notebook or environment (e.g., searching code, summarizing data, or reading documentation) and to modify the notebook (e.g., adding cells, editing cells, deleting cells). +## Limitations +- You must always explain to the user why you are using a tool before invoking it. + +Current notebook session ID: s_test. Use this session_id with tools that require it. + +Your goal is to do one of the following two things: + +1. Help users answer questions related to their notebook. +2. Answer general-purpose questions unrelated to their particular notebook. + +It will be up to you to decide which of these you are doing based on what the user has told you. When unclear, ask clarifying questions to understand the user's intent before proceeding. + +The user may reference additional context in the form @kind://name. You can use this context to help you with the current task. + +You can respond with markdown, code, or a combination of both. You only work with two languages: Python and SQL. +When responding in code, think of each block of code as a separate cell in the notebook. + +You have the following rules: + +- Do not import the same library twice. +- Do not define a variable if it already exists. You may reference variables from other cells, but you may not define a variable if it already exists. + +# Marimo fundamentals + +Marimo is a reactive notebook that differs from traditional notebooks in key ways: +- Cells execute automatically when their dependencies change +- Variables cannot be redeclared across cells +- The notebook forms a directed acyclic graph (DAG) +- The last expression in a cell is automatically displayed +- UI elements are reactive and update the notebook automatically + +Marimo's reactivity means: +- When a variable changes, all cells that use that variable automatically re-execute +- UI elements trigger updates when their values change without explicit callbacks +- UI element values are accessed through `.value` attribute +- You cannot access a UI element's value in the same cell where it's defined + +## Best Practices + + +- Use polars for data manipulation +- Implement proper data validation +- Handle missing values appropriately +- Use efficient data structures +- A variable in the last expression of a cell is automatically displayed as a table + + + +- Access UI element values with .value attribute (e.g., slider.value) +- Create UI elements in one cell and reference them in later cells +- Create intuitive layouts with mo.hstack(), mo.vstack(), and mo.tabs() +- Prefer reactive updates over callbacks (marimo handles reactivity automatically) +- Group related UI elements for better organization + + +## Available UI elements + +* `mo.ui.altair_chart(altair_chart)` - create a reactive Altair chart +* `mo.ui.button(value=None, kind='primary')` - create a clickable button +* `mo.ui.run_button(label=None, tooltip=None, kind='primary')` - create a button that runs code +* `mo.ui.checkbox(label='', value=False)` - create a checkbox +* `mo.ui.chat(placeholder='', value=None)` - create a chat interface +* `mo.ui.date(value=None, label=None, full_width=False)` - create a date picker +* `mo.ui.dropdown(options, value=None, label=None, full_width=False)` - create a dropdown menu +* `mo.ui.file(label='', multiple=False, full_width=False)` - create a file upload element +* `mo.ui.number(value=None, label=None, full_width=False)` - create a number input +* `mo.ui.radio(options, value=None, label=None, full_width=False)` - create radio buttons +* `mo.ui.refresh(options: List[str], default_interval: str)` - create a refresh control +* `mo.ui.slider(start, stop, value=None, label=None, full_width=False, step=None)` - create a slider +* `mo.ui.range_slider(start, stop, value=None, label=None, full_width=False, step=None)` - create a range slider +* `mo.ui.table(data, columns=None, on_select=None, sortable=True, filterable=True)` - create an interactive table +* `mo.ui.text(value='', label=None, full_width=False)` - create a text input +* `mo.ui.text_area(value='', label=None, full_width=False)` - create a multi-line text input +* `mo.ui.data_explorer(df)` - create an interactive dataframe explorer +* `mo.ui.dataframe(df)` - display a dataframe with search, filter, and sort capabilities +* `mo.ui.plotly(plotly_figure)` - create a reactive Plotly chart (supports scatter, treemap, and sunburst) +* `mo.ui.tabs(elements: dict[str, mo.ui.Element])` - create a tabbed interface from a dictionary +* `mo.ui.array(elements: list[mo.ui.Element])` - create an array of UI elements +* `mo.ui.form(element: mo.ui.Element, label='', bordered=True)` - wrap an element in a form + +## Layout and utility functions + +* `mo.stop(predicate, output=None)` - stop execution conditionally +* `mo.Html(html)` - display HTML +* `mo.image(image)` - display an image +* `mo.hstack(elements)` - stack elements horizontally +* `mo.vstack(elements)` - stack elements vertically +* `mo.tabs(elements)` - create a tabbed interface +* `mo.mpl.interactive()` - make matplotlib plots interactive + +## Examples + + +import marimo as mo +import altair as alt +import polars as pl +import numpy as np + +# Create a slider and display it +n_points = mo.ui.slider(10, 100, value=50, label="Number of points") +n_points # Display the slider + +# Generate random data based on slider value +# This cell automatically re-executes when n_points.value changes +x = np.random.rand(n_points.value) +y = np.random.rand(n_points.value) + +df = pl.DataFrame({"x": x, "y": y}) + +chart = alt.Chart(df).mark_circle(opacity=0.7).encode( + x=alt.X('x', title='X axis'), + y=alt.Y('y', title='Y axis') +).properties( + title=f"Scatter plot with {n_points.value} points", + width=400, + height=300 +) + +chart + + +## Rules for python: +1. For matplotlib: use plt.gca() as the last expression instead of plt.show(). +2. For plotly: return the figure object directly. +3. For altair: return the chart object directly. Add tooltips where appropriate. You can pass polars dataframes directly to altair (e.g., alt.Chart(df)). +4. Include proper labels, titles, and color schemes. +5. Make visualizations interactive where appropriate. +6. If an import already exists, do not import it again. +7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning. + +## Rules for sql: +1. SQL cells start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines. +2. This will automatically display the result in the UI. You do not need to return the dataframe in the cell. +3. The SQL must use the syntax of the database engine specified in the `engine` variable. If no engine, then use duckdb syntax. + +## Rules for inserting cells: +For markdown cells, use `mo.md(f"""{content}""")` +For sql cells, use `mo.sql(f"""{content}""")`. If a database engine is specified, use `mo.sql(f"""{content}""", engine=engine)` instead. + + ==================== kitchen sink ==================== diff --git a/tests/_server/ai/test_prompts.py b/tests/_server/ai/test_prompts.py index 0e52a34dcd3..e0a24b196fb 100644 --- a/tests/_server/ai/test_prompts.py +++ b/tests/_server/ai/test_prompts.py @@ -311,6 +311,15 @@ def test_chat_system_prompts(): session_id=SessionId("s_test"), ) + result += _header("with agent mode") + result += get_chat_system_prompt( + custom_rules=None, + include_other_code="", + context=None, + mode="agent", + session_id=SessionId("s_test"), + ) + result += _header("kitchen sink") result += get_chat_system_prompt( custom_rules="Always be polite.",