From 4996e4fb4bf4f88bff7625800959a6e9380b4c39 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Wed, 15 Oct 2025 18:12:42 +0800 Subject: [PATCH 01/10] wip --- frontend/src/components/chat/chat-panel.tsx | 16 +- .../editor/ai/ai-completion-editor.tsx | 43 ++++-- .../components/editor/cell/StagedAICell.tsx | 29 +++- .../editor/cell/code/cell-editor.tsx | 5 +- frontend/src/core/ai/config.ts | 8 +- frontend/src/core/ai/staged-cells.ts | 25 +++- .../src/core/ai/tools/edit-notebook-tool.ts | 120 +++++++++++++++ frontend/src/core/config/config-schema.ts | 5 +- marimo/_config/config.py | 2 +- marimo/_server/ai/prompts.py | 8 + marimo/_server/ai/tools/tool_manager.py | 8 +- packages/openapi/api.yaml | 2 + packages/openapi/src/api.ts | 4 +- .../ai/snapshots/chat_system_prompts.txt | 140 ++++++++++++++++++ tests/_server/ai/test_prompts.py | 9 ++ 15 files changed, 392 insertions(+), 32 deletions(-) create mode 100644 frontend/src/core/ai/tools/edit-notebook-tool.ts diff --git a/frontend/src/components/chat/chat-panel.tsx b/frontend/src/components/chat/chat-panel.tsx index 6d223f12d6d..79d5f1974ec 100644 --- a/frontend/src/components/chat/chat-panel.tsx +++ b/frontend/src/components/chat/chat-panel.tsx @@ -36,7 +36,10 @@ import { type ChatId, chatStateAtom, } from "@/core/ai/state"; -import { FRONTEND_TOOL_REGISTRY } from "@/core/ai/tools/registry"; +import { + type CopilotMode, + FRONTEND_TOOL_REGISTRY, +} from "@/core/ai/tools/registry"; import { aiAtom, aiEnabledAtom } from "@/core/config/config"; import { DEFAULT_AI_MODEL } from "@/core/config/config-schema"; import { FeatureFlagged } from "@/core/config/feature-flag"; @@ -301,7 +304,11 @@ const ChatInputFooter: React.FC = memo( const { saveModeChange } = useModelChange(); - const modeOptions = [ + const modeOptions: { + value: CopilotMode; + label: string; + subtitle: string; + }[] = [ { value: "ask", label: "Ask", @@ -313,6 +320,11 @@ const ChatInputFooter: React.FC = memo( label: "Manual", subtitle: "Pure chat, no tool usage", }, + { + value: "agent", + label: "Agent", + subtitle: "Use AI with access to read and write tools", + }, ]; const isAttachmentSupported = diff --git a/frontend/src/components/editor/ai/ai-completion-editor.tsx b/frontend/src/components/editor/ai/ai-completion-editor.tsx index 58c104fa2e0..5d5b07cfe56 100644 --- a/frontend/src/components/editor/ai/ai-completion-editor.tsx +++ b/frontend/src/components/editor/ai/ai-completion-editor.tsx @@ -11,14 +11,16 @@ import { customPythonLanguageSupport } from "@/core/codemirror/language/language import "./merge-editor.css"; import { storePrompt } from "@marimo-team/codemirror-ai"; import type { ReactCodeMirrorRef } from "@uiw/react-codemirror"; -import { useAtom } from "jotai"; +import { useAtom, useAtomValue } from "jotai"; import { AIModelDropdown } from "@/components/ai/ai-model-dropdown"; import { Checkbox } from "@/components/ui/checkbox"; import { Label } from "@/components/ui/label"; import { Switch } from "@/components/ui/switch"; import { Tooltip } from "@/components/ui/tooltip"; import { toast } from "@/components/ui/use-toast"; -import { includeOtherCellsAtom } from "@/core/ai/state"; +import { stagedAICellsAtom } from "@/core/ai/staged-cells"; +import { type AiCompletionCell, includeOtherCellsAtom } from "@/core/ai/state"; +import type { CellId } from "@/core/cells/ids"; import { getCodes } from "@/core/codemirror/copilot/getCodes"; import type { LanguageAdapterType } from "@/core/codemirror/language/types"; import { selectAllText } from "@/core/codemirror/utils"; @@ -40,15 +42,14 @@ const Original = CodeMirrorMerge.Original; const Modified = CodeMirrorMerge.Modified; interface Props { + cellId: CellId; + aiCompletionCell: AiCompletionCell | null; className?: string; currentCode: string; currentLanguageAdapter: LanguageAdapterType | undefined; - initialPrompt: string | undefined; onChange: (code: string) => void; declineChange: () => void; acceptChange: (rightHandCode: string) => void; - enabled: boolean; - triggerImmediately?: boolean; runCell: () => void; outputArea?: "above" | "below"; /** @@ -65,15 +66,14 @@ const baseExtensions = [customPythonLanguageSupport(), EditorView.lineWrapping]; * This shows a left/right split with the original and modified code. */ export const AiCompletionEditor: React.FC = ({ + cellId, + aiCompletionCell, className, onChange, - initialPrompt, currentLanguageAdapter, currentCode, declineChange, acceptChange, - enabled, - triggerImmediately, runCell, outputArea, children, @@ -88,6 +88,20 @@ export const AiCompletionEditor: React.FC = ({ const runtimeManager = useRuntimeManager(); + const { + initialPrompt, + triggerImmediately, + cellId: aiCellId, + } = aiCompletionCell ?? {}; + const enabled = aiCellId === cellId; + + const stagedAICells = useAtomValue(stagedAICellsAtom); + const updatedCell = stagedAICells.get(cellId); + let previousCellCode: string | undefined; + if (updatedCell?.type === "update_cell") { + previousCellCode = updatedCell.previousCode; + } + const { completion: untrimmedCompletion, input, @@ -333,7 +347,18 @@ export const AiCompletionEditor: React.FC = ({ /> )} - {(!completion || !enabled) && children} + {previousCellCode && ( + + + + + )} + {(!completion || !enabled) && !previousCellCode && children} {/* By default, show the completion banner below the code */} {(outputArea === "below" || !outputArea) && completionBanner} diff --git a/frontend/src/components/editor/cell/StagedAICell.tsx b/frontend/src/components/editor/cell/StagedAICell.tsx index 0806fb9d3a7..ece458a9696 100644 --- a/frontend/src/components/editor/cell/StagedAICell.tsx +++ b/frontend/src/components/editor/cell/StagedAICell.tsx @@ -2,8 +2,11 @@ import { useAtomValue, useStore } from "jotai"; import { stagedAICellsAtom, useStagedCells } from "@/core/ai/staged-cells"; +import { cellHandleAtom } from "@/core/cells/cells"; import type { CellId } from "@/core/cells/ids"; +import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils"; import { cn } from "@/utils/cn"; +import { Logger } from "@/utils/Logger"; import { CompletionActionsCellFooter } from "../ai/completion-handlers"; export const StagedAICellBackground: React.FC<{ @@ -24,9 +27,11 @@ export const StagedAICellFooter: React.FC<{ cellId: CellId }> = ({ }) => { const store = useStore(); const stagedAICells = useAtomValue(stagedAICellsAtom); + const stagedAiCell = stagedAICells.get(cellId); + const { deleteStagedCell, removeStagedCell } = useStagedCells(store); - if (!stagedAICells.has(cellId)) { + if (!stagedAiCell) { return null; } @@ -35,7 +40,27 @@ export const StagedAICellFooter: React.FC<{ cellId: CellId }> = ({ }; const handleDeclineCompletion = () => { - deleteStagedCell(cellId); + switch (stagedAiCell.type) { + case "update_cell": { + // Revert cell code + const cellHandle = store.get(cellHandleAtom(cellId)); + const editorView = cellHandle?.current?.editorView; + if (!editorView) { + Logger.error("Editor for this cell not found", { cellId }); + break; + } + + updateEditorCodeFromPython(editorView, stagedAiCell.previousCode); + removeStagedCell(cellId); + break; + } + case "add_cell": + // Delete the cell since it's newly created + deleteStagedCell(cellId); + break; + case "delete_cell": + break; + } }; return ( diff --git a/frontend/src/components/editor/cell/code/cell-editor.tsx b/frontend/src/components/editor/cell/code/cell-editor.tsx index daddecf26db..f1036459ffc 100644 --- a/frontend/src/components/editor/cell/code/cell-editor.tsx +++ b/frontend/src/components/editor/cell/code/cell-editor.tsx @@ -407,9 +407,8 @@ const CellEditorInternal = ({ return ( { diff --git a/frontend/src/core/ai/config.ts b/frontend/src/core/ai/config.ts index d7f1abde916..c78a31a063b 100644 --- a/frontend/src/core/ai/config.ts +++ b/frontend/src/core/ai/config.ts @@ -4,7 +4,11 @@ import type { Role } from "@marimo-team/llm-info"; import { useAtom } from "jotai"; import type { QualifiedModelId } from "@/core/ai/ids/ids"; import { userConfigAtom } from "@/core/config/config"; -import type { AIModelKey, UserConfig } from "@/core/config/config-schema"; +import type { + AIModelKey, + CopilotMode, + UserConfig, +} from "@/core/config/config-schema"; import { useRequestClient } from "@/core/network/requests"; // Extract only the supported roles from the Role type @@ -60,7 +64,7 @@ export const useModelChange = () => { saveConfig(newConfig); }; - const saveModeChange = async (newMode: "ask" | "manual") => { + const saveModeChange = async (newMode: CopilotMode) => { const newConfig: UserConfig = { ...userConfig, ai: { diff --git a/frontend/src/core/ai/staged-cells.ts b/frontend/src/core/ai/staged-cells.ts index a8e77e3cd95..231eeb6a027 100644 --- a/frontend/src/core/ai/staged-cells.ts +++ b/frontend/src/core/ai/staged-cells.ts @@ -19,15 +19,24 @@ import { import type { LanguageAdapterType } from "../codemirror/language/types"; import { updateEditorCodeFromPython } from "../codemirror/language/utils"; import type { JotaiStore } from "../state/jotai"; +import type { EditType } from "./tools/edit-notebook-tool"; /** * Cells that are staged for AI completion - * They function similarly to cells in the notebook, but they can be deleted or accepted by the user. - * We only track one set of staged cells at a time. + * They function similarly to cells in the notebook, but they can be accepted or rejected by the user. + * We track edited, new and deleted cells. + * And we only track one set of staged cells at a time. */ -const initialState = (): Set => { - return new Set(); +type Edit = + | { type: Extract; previousCode: string } + | { type: Extract } + | { type: Extract; previousCode: string }; + +type StagedAICells = Map; + +const initialState = (): StagedAICells => { + return new Map(); }; const { @@ -38,10 +47,12 @@ const { } = createReducerAndAtoms(initialState, { addStagedCell: (state, action: { cellId: CellId }) => { const { cellId } = action; - return new Set([...state, cellId]); + return new Map([...state, [cellId, { type: "add_cell" }]]); }, removeStagedCell: (state, cellId: CellId) => { - return new Set([...state].filter((id) => id !== cellId)); + const newState = new Map(state); + newState.delete(cellId); + return newState; }, clearStagedCells: () => { return initialState(); @@ -105,7 +116,7 @@ export function useStagedCells(store: JotaiStore) { // Delete all staged cells and the corresponding cells in the notebook. const deleteAllStagedCells = () => { const stagedAICells = store.get(stagedAICellsAtom); - for (const cellId of stagedAICells) { + for (const cellId of stagedAICells.keys()) { deleteCellCallback({ cellId }); } clearStagedCells(); diff --git a/frontend/src/core/ai/tools/edit-notebook-tool.ts b/frontend/src/core/ai/tools/edit-notebook-tool.ts new file mode 100644 index 00000000000..34d58407fe4 --- /dev/null +++ b/frontend/src/core/ai/tools/edit-notebook-tool.ts @@ -0,0 +1,120 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { z } from "zod"; +import { notebookAtom } from "@/core/cells/cells"; +import type { CellId } from "@/core/cells/ids"; +import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils"; +import type { JotaiStore } from "@/core/state/jotai"; +import { stagedAICellsAtom } from "../staged-cells"; +import { + type AiTool, + ToolExecutionError, + type ToolOutputBase, + toolOutputBaseSchema, +} from "./base"; +import type { CopilotMode } from "./registry"; + +const description = ` +Perform editing operations on the current notebook. +Call this tool multiple times to perform multiple edits. + +Args: +- edit (object): The editing operation to perform. Must be one of: + - update_cell: Update the code of an existing cell. + - add_cell: Add a new cell to the notebook. + - delete_cell: Delete an existing cell. + +Returns: +- A result object containing standard tool metadata. +`; + +const editNotebookSchema = z.object({ + edit: z.discriminatedUnion("type", [ + z.object({ + type: z.literal("update_cell"), + cellId: z.string() as unknown as z.ZodType, + code: z.string(), + }), + z.object({ + type: z.literal("add_cell"), + cellId: z.string() as unknown as z.ZodType, + code: z.string(), + language: z.enum(["python", "sql", "markdown"]).optional(), + }), + z.object({ + type: z.literal("delete_cell"), + cellId: z.string() as unknown as z.ZodType, + }), + ]), +}); + +type EditNotebookInput = z.infer; +type EditOperation = EditNotebookInput["edit"]; +export type EditType = EditOperation["type"]; + +export class EditNotebookTool + implements AiTool +{ + private readonly store: JotaiStore; + readonly name = "edit_notebook_tool"; + readonly description = description; + readonly schema = editNotebookSchema; + readonly outputSchema = toolOutputBaseSchema; + readonly mode: CopilotMode[] = ["agent"]; + + constructor(store: JotaiStore) { + this.store = store; + } + + handler = async ({ edit }: EditNotebookInput): Promise => { + switch (edit.type) { + case "update_cell": { + const { cellId, code } = edit; + + const notebook = this.store.get(notebookAtom); + const cellIds = notebook.cellIds; + if (!cellIds.getColumns().some((column) => column.idSet.has(cellId))) { + throw new ToolExecutionError( + "Cell not found", + "CELL_NOT_FOUND", + false, + "Check which cells exist in the notebook", + ); + } + + const cellHandles = notebook.cellHandles; + const cellHandle = cellHandles[cellId].current; + if (!cellHandle?.editorView) { + throw new ToolExecutionError( + "Cell editor not found", + "CELL_EDITOR_NOT_FOUND", + false, + "Internal error, ask the user to report this error", + ); + } + + const currentCellCode = cellHandle.editorView.state.doc.toString(); + + const stagedAICells = this.store.get(stagedAICellsAtom); + const newStagedAICells = new Map([ + ...stagedAICells, + [cellId, { type: "update_cell", previousCode: currentCellCode }], + ]); + this.store.set(stagedAICellsAtom, newStagedAICells); + + updateEditorCodeFromPython(cellHandle.editorView, code); + + break; + } + case "add_cell": + // TODO + break; + case "delete_cell": + // TODO + break; + } + return { + status: "success", + }; + }; +} diff --git a/frontend/src/core/config/config-schema.ts b/frontend/src/core/config/config-schema.ts index c513b4ef5c6..4e486e57a72 100644 --- a/frontend/src/core/config/config-schema.ts +++ b/frontend/src/core/config/config-schema.ts @@ -44,6 +44,9 @@ export const DEFAULT_AI_MODEL = "openai/gpt-4o"; */ const AUTO_DOWNLOAD_FORMATS = ["html", "markdown", "ipynb"] as const; +const COPILOT_MODES = ["manual", "ask", "agent"] as const; +export type CopilotMode = (typeof COPILOT_MODES)[number]; + const AiConfigSchema = z .object({ api_key: z.string().optional(), @@ -151,7 +154,7 @@ export const UserConfigSchema = z ai: z .looseObject({ rules: z.string().prefault(""), - mode: z.enum(["manual", "ask"]).prefault("manual"), + mode: z.enum(COPILOT_MODES).prefault("manual"), inline_tooltip: z.boolean().prefault(false), open_ai: AiConfigSchema.optional(), anthropic: AiConfigSchema.optional(), diff --git a/marimo/_config/config.py b/marimo/_config/config.py index 72923a325d2..5802e60eb21 100644 --- a/marimo/_config/config.py +++ b/marimo/_config/config.py @@ -231,7 +231,7 @@ class PackageManagementConfig(TypedDict): manager: Literal["pip", "rye", "uv", "poetry", "pixi"] -CopilotMode = Literal["ask", "manual"] +CopilotMode = Literal["ask", "manual", "agent"] @mddoc diff --git a/marimo/_server/ai/prompts.py b/marimo/_server/ai/prompts.py index cc8da16e471..df5dc278972 100644 --- a/marimo/_server/ai/prompts.py +++ b/marimo/_server/ai/prompts.py @@ -244,6 +244,14 @@ def _get_mode_intro_message(mode: CopilotMode) -> str: "- All tool use is strictly read-only. You may not perform write, edit, or execution actions.\n" "- You must always explain to the user why you are using a tool before invoking it.\n" ) + elif mode == "agent": + return ( + f"{base_intro}" + "## Capabilities\n" + "- You can use a set of read and write tools to gather additional context from the notebook or environment (e.g., searching code, summarizing data, or reading documentation) and to modify the notebook (e.g., adding cells, editing cells, deleting cells).\n" + "## Limitations\n" + "- You must always explain to the user why you are using a tool before invoking it.\n" + ) def _get_session_info(session_id: SessionId) -> str: diff --git a/marimo/_server/ai/tools/tool_manager.py b/marimo/_server/ai/tools/tool_manager.py index cc616057e96..11f6d75734d 100644 --- a/marimo/_server/ai/tools/tool_manager.py +++ b/marimo/_server/ai/tools/tool_manager.py @@ -53,7 +53,8 @@ def _register_backend_tool(self, tool: ToolBase[Any, Any]) -> None: """Register a backend tool with its handler function and optional validator.""" name = tool.name tool_definition, validation_function = tool.as_backend_tool( - mode=["ask"] + # TODO: Tools should define their own supported modes + mode=["ask", "agent"] ) self._tools[name] = tool_definition self._backend_handlers[name] = tool.__call__ @@ -175,8 +176,9 @@ def _convert_mcp_tool(self, mcp_tool: MCPRawTool) -> ToolDefinition: description=mcp_tool.description or "No description available", parameters=mcp_tool.inputSchema, source="mcp", - mode=["ask"], # MCP tools available in ask mode for now - # TODO(bjoaquinc): change default mode to "agent" when we add agent mode + # MCP tools available in ask mode and agent mode + # TODO: Determine which tools to support in agent mode + mode=["ask", "agent"], ) async def invoke_tool( diff --git a/packages/openapi/api.yaml b/packages/openapi/api.yaml index e6ccafe9749..9cd5bded4d9 100644 --- a/packages/openapi/api.yaml +++ b/packages/openapi/api.yaml @@ -114,6 +114,7 @@ components: type: integer mode: enum: + - agent - ask - manual models: @@ -3403,6 +3404,7 @@ components: mode: items: enum: + - agent - ask - manual type: array diff --git a/packages/openapi/src/api.ts b/packages/openapi/src/api.ts index aaef4207ea8..5c5fce39b87 100644 --- a/packages/openapi/src/api.ts +++ b/packages/openapi/src/api.ts @@ -2847,7 +2847,7 @@ export interface components { inline_tooltip?: boolean; max_tokens?: number; /** @enum {unknown} */ - mode?: "ask" | "manual"; + mode?: "agent" | "ask" | "manual"; models?: components["schemas"]["AiModelConfig"]; ollama?: components["schemas"]["OpenAiConfig"]; open_ai?: components["schemas"]["OpenAiConfig"]; @@ -4614,7 +4614,7 @@ export interface components { */ ToolDefinition: { description: string; - mode: ("ask" | "manual")[]; + mode: ("agent" | "ask" | "manual")[]; name: string; parameters: Record; /** @enum {unknown} */ diff --git a/tests/_server/ai/snapshots/chat_system_prompts.txt b/tests/_server/ai/snapshots/chat_system_prompts.txt index 69f3377a28e..076475eec3a 100644 --- a/tests/_server/ai/snapshots/chat_system_prompts.txt +++ b/tests/_server/ai/snapshots/chat_system_prompts.txt @@ -889,6 +889,146 @@ import pandas as pd import numpy as np +==================== with agent mode ==================== + + +You are Marimo Copilot, an AI assistant integrated into the marimo notebook code editor. +Your primary function is to help users create, analyze, and improve data science notebooks using marimo's reactive programming model. +## Capabilities +- You can use a set of read and write tools to gather additional context from the notebook or environment (e.g., searching code, summarizing data, or reading documentation) and to modify the notebook (e.g., adding cells, editing cells, deleting cells). +## Limitations +- You must always explain to the user why you are using a tool before invoking it. + +Current notebook session ID: s_test. Use this session_id with tools that require it. + +Your goal is to do one of the following two things: + +1. Help users answer questions related to their notebook. +2. Answer general-purpose questions unrelated to their particular notebook. + +It will be up to you to decide which of these you are doing based on what the user has told you. When unclear, ask clarifying questions to understand the user's intent before proceeding. + +The user may reference additional context in the form @kind://name. You can use this context to help you with the current task. + +You can respond with markdown, code, or a combination of both. You only work with two languages: Python and SQL. +When responding in code, think of each block of code as a separate cell in the notebook. + +You have the following rules: + +- Do not import the same library twice. +- Do not define a variable if it already exists. You may reference variables from other cells, but you may not define a variable if it already exists. + +# Marimo fundamentals + +Marimo is a reactive notebook that differs from traditional notebooks in key ways: +- Cells execute automatically when their dependencies change +- Variables cannot be redeclared across cells +- The notebook forms a directed acyclic graph (DAG) +- The last expression in a cell is automatically displayed +- UI elements are reactive and update the notebook automatically + +Marimo's reactivity means: +- When a variable changes, all cells that use that variable automatically re-execute +- UI elements trigger updates when their values change without explicit callbacks +- UI element values are accessed through `.value` attribute +- You cannot access a UI element's value in the same cell where it's defined + +## Best Practices + + +- Use polars for data manipulation +- Implement proper data validation +- Handle missing values appropriately +- Use efficient data structures +- A variable in the last expression of a cell is automatically displayed as a table + + + +- Access UI element values with .value attribute (e.g., slider.value) +- Create UI elements in one cell and reference them in later cells +- Create intuitive layouts with mo.hstack(), mo.vstack(), and mo.tabs() +- Prefer reactive updates over callbacks (marimo handles reactivity automatically) +- Group related UI elements for better organization + + +## Available UI elements + +* `mo.ui.altair_chart(altair_chart)` - create a reactive Altair chart +* `mo.ui.button(value=None, kind='primary')` - create a clickable button +* `mo.ui.run_button(label=None, tooltip=None, kind='primary')` - create a button that runs code +* `mo.ui.checkbox(label='', value=False)` - create a checkbox +* `mo.ui.chat(placeholder='', value=None)` - create a chat interface +* `mo.ui.date(value=None, label=None, full_width=False)` - create a date picker +* `mo.ui.dropdown(options, value=None, label=None, full_width=False)` - create a dropdown menu +* `mo.ui.file(label='', multiple=False, full_width=False)` - create a file upload element +* `mo.ui.number(value=None, label=None, full_width=False)` - create a number input +* `mo.ui.radio(options, value=None, label=None, full_width=False)` - create radio buttons +* `mo.ui.refresh(options: List[str], default_interval: str)` - create a refresh control +* `mo.ui.slider(start, stop, value=None, label=None, full_width=False, step=None)` - create a slider +* `mo.ui.range_slider(start, stop, value=None, label=None, full_width=False, step=None)` - create a range slider +* `mo.ui.table(data, columns=None, on_select=None, sortable=True, filterable=True)` - create an interactive table +* `mo.ui.text(value='', label=None, full_width=False)` - create a text input +* `mo.ui.text_area(value='', label=None, full_width=False)` - create a multi-line text input +* `mo.ui.data_explorer(df)` - create an interactive dataframe explorer +* `mo.ui.dataframe(df)` - display a dataframe with search, filter, and sort capabilities +* `mo.ui.plotly(plotly_figure)` - create a reactive Plotly chart (supports scatter, treemap, and sunburst) +* `mo.ui.tabs(elements: dict[str, mo.ui.Element])` - create a tabbed interface from a dictionary +* `mo.ui.array(elements: list[mo.ui.Element])` - create an array of UI elements +* `mo.ui.form(element: mo.ui.Element, label='', bordered=True)` - wrap an element in a form + +## Layout and utility functions + +* `mo.stop(predicate, output=None)` - stop execution conditionally +* `mo.Html(html)` - display HTML +* `mo.image(image)` - display an image +* `mo.hstack(elements)` - stack elements horizontally +* `mo.vstack(elements)` - stack elements vertically +* `mo.tabs(elements)` - create a tabbed interface +* `mo.mpl.interactive()` - make matplotlib plots interactive + +## Examples + + +import marimo as mo +import altair as alt +import polars as pl +import numpy as np + +# Create a slider and display it +n_points = mo.ui.slider(10, 100, value=50, label="Number of points") +n_points # Display the slider + +# Generate random data based on slider value +# This cell automatically re-executes when n_points.value changes +x = np.random.rand(n_points.value) +y = np.random.rand(n_points.value) + +df = pl.DataFrame({"x": x, "y": y}) + +chart = alt.Chart(df).mark_circle(opacity=0.7).encode( + x=alt.X('x', title='X axis'), + y=alt.Y('y', title='Y axis') +).properties( + title=f"Scatter plot with {n_points.value} points", + width=400, + height=300 +) + +chart + + +## Rules for python: +1. For matplotlib: use plt.gca() as the last expression instead of plt.show(). +2. For plotly: return the figure object directly. +3. For altair: return the chart object directly. Add tooltips where appropriate. You can pass polars dataframes directly to altair (e.g., alt.Chart(df)). +4. Include proper labels, titles, and color schemes. +5. Make visualizations interactive where appropriate. +6. If an import already exists, do not import it again. +7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning. + +## Rules for sql: +1. The SQL must use duckdb syntax. + ==================== kitchen sink ==================== diff --git a/tests/_server/ai/test_prompts.py b/tests/_server/ai/test_prompts.py index 0e52a34dcd3..e0a24b196fb 100644 --- a/tests/_server/ai/test_prompts.py +++ b/tests/_server/ai/test_prompts.py @@ -311,6 +311,15 @@ def test_chat_system_prompts(): session_id=SessionId("s_test"), ) + result += _header("with agent mode") + result += get_chat_system_prompt( + custom_rules=None, + include_other_code="", + context=None, + mode="agent", + session_id=SessionId("s_test"), + ) + result += _header("kitchen sink") result += get_chat_system_prompt( custom_rules="Always be polite.", From b34f6e2dab6d0194c0b80c642d54d311ac6490dd Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 16 Oct 2025 17:22:00 +0800 Subject: [PATCH 02/10] handling of new cells and edit cells --- .../editor/ai/ai-completion-editor.tsx | 60 ++--- .../components/editor/cell/StagedAICell.tsx | 6 +- frontend/src/core/ai/staged-cells.ts | 24 +- .../src/core/ai/tools/__tests__/utils.test.ts | 87 +++++++ frontend/src/core/ai/tools/base.ts | 17 +- .../src/core/ai/tools/edit-notebook-tool.ts | 215 ++++++++++++++---- frontend/src/core/ai/tools/registry.ts | 8 +- frontend/src/core/ai/tools/sample-tool.ts | 18 +- frontend/src/core/ai/tools/utils.ts | 23 ++ frontend/src/core/cells/cells.ts | 17 +- 10 files changed, 364 insertions(+), 111 deletions(-) create mode 100644 frontend/src/core/ai/tools/__tests__/utils.test.ts create mode 100644 frontend/src/core/ai/tools/utils.ts diff --git a/frontend/src/components/editor/ai/ai-completion-editor.tsx b/frontend/src/components/editor/ai/ai-completion-editor.tsx index 5d5b07cfe56..f02f500fa89 100644 --- a/frontend/src/components/editor/ai/ai-completion-editor.tsx +++ b/frontend/src/components/editor/ai/ai-completion-editor.tsx @@ -192,6 +192,8 @@ export const AiCompletionEditor: React.FC = ({ const showCompletionBanner = enabled && triggerImmediately && (completion || isLoading); + // Set default output area to below if not specified + outputArea = outputArea === undefined ? "below" : outputArea; const showInput = enabled && (!triggerImmediately || showInputPrompt); @@ -216,6 +218,35 @@ export const AiCompletionEditor: React.FC = ({ ); + const renderMergeEditor = (originalCode: string, modifiedCode: string) => { + return ( + + + + + ); + }; + + const renderCompletionEditor = () => { + if (completion && enabled) { + return renderMergeEditor(currentCode, completion); + } + // If there is no completion and there is previous cell code, it means there is an AI change to the cell. + // And we want to render the previous cell code as the original + if (!completion && previousCellCode) { + return renderMergeEditor(previousCellCode, currentCode); + } + }; + return (
= ({ )}
{outputArea === "above" && completionBanner} - {completion && enabled && ( - - - - - )} - {previousCellCode && ( - - - - - )} + {renderCompletionEditor()} {(!completion || !enabled) && !previousCellCode && children} {/* By default, show the completion banner below the code */} - {(outputArea === "below" || !outputArea) && completionBanner} + {outputArea === "below" && completionBanner}
); }; diff --git a/frontend/src/components/editor/cell/StagedAICell.tsx b/frontend/src/components/editor/cell/StagedAICell.tsx index ece458a9696..070879da12a 100644 --- a/frontend/src/components/editor/cell/StagedAICell.tsx +++ b/frontend/src/components/editor/cell/StagedAICell.tsx @@ -2,7 +2,7 @@ import { useAtomValue, useStore } from "jotai"; import { stagedAICellsAtom, useStagedCells } from "@/core/ai/staged-cells"; -import { cellHandleAtom } from "@/core/cells/cells"; +import { getCellEditorView } from "@/core/cells/cells"; import type { CellId } from "@/core/cells/ids"; import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils"; import { cn } from "@/utils/cn"; @@ -43,8 +43,7 @@ export const StagedAICellFooter: React.FC<{ cellId: CellId }> = ({ switch (stagedAiCell.type) { case "update_cell": { // Revert cell code - const cellHandle = store.get(cellHandleAtom(cellId)); - const editorView = cellHandle?.current?.editorView; + const editorView = getCellEditorView(cellId); if (!editorView) { Logger.error("Editor for this cell not found", { cellId }); break; @@ -59,6 +58,7 @@ export const StagedAICellFooter: React.FC<{ cellId: CellId }> = ({ deleteStagedCell(cellId); break; case "delete_cell": + // TODO: Revert delete break; } }; diff --git a/frontend/src/core/ai/staged-cells.ts b/frontend/src/core/ai/staged-cells.ts index 231eeb6a027..5a2dd656c61 100644 --- a/frontend/src/core/ai/staged-cells.ts +++ b/frontend/src/core/ai/staged-cells.ts @@ -13,7 +13,7 @@ import { Logger } from "@/utils/Logger"; import { maybeAddMarimoImport } from "../cells/add-missing-import"; import { type CreateNewCellAction, - cellHandleAtom, + getCellEditorView, useCellActions, } from "../cells/cells"; import type { LanguageAdapterType } from "../codemirror/language/types"; @@ -45,9 +45,9 @@ const { createActions, reducer, } = createReducerAndAtoms(initialState, { - addStagedCell: (state, action: { cellId: CellId }) => { - const { cellId } = action; - return new Map([...state, [cellId, { type: "add_cell" }]]); + addStagedCell: (state, action: { cellId: CellId; edit: Edit }) => { + const { cellId, edit } = action; + return new Map([...state, [cellId, edit]]); }, removeStagedCell: (state, cellId: CellId) => { const newState = new Map(state); @@ -59,6 +59,11 @@ const { }, }); +export { + createActions as createStagedAICellsActions, + reducer as stagedAICellsReducer, +}; + interface UpdateStagedCellAction { cellId: CellId; code: string; @@ -78,7 +83,7 @@ export function useStagedCells(store: JotaiStore) { const createStagedCell = (code: string): CellId => { const newCellId = CellId.create(); - addStagedCell({ cellId: newCellId }); + addStagedCell({ cellId: newCellId, edit: { type: "add_cell" } }); createNewCell({ cellId: "__end__", code, @@ -97,8 +102,7 @@ export function useStagedCells(store: JotaiStore) { return; } - const cellHandle = store.get(cellHandleAtom(cellId)); - const editorView = cellHandle?.current?.editorViewOrNull; + const editorView = getCellEditorView(cellId); if (!editorView) { Logger.error("Editor for this cell not found", { cellId }); return; @@ -186,14 +190,14 @@ class CellCreationStream { private onCreateCell: (code: string) => CellId; private onUpdateCell: (opts: UpdateStagedCellAction) => void; - private addStagedCell: (payload: { cellId: CellId }) => void; + private addStagedCell: (payload: { cellId: CellId; edit: Edit }) => void; private createNewCell: (opts: CreateNewCellAction) => void; private hasMarimoImport = false; constructor( onCreateCell: (code: string) => CellId, onUpdateCell: (opts: UpdateStagedCellAction) => void, - addStagedCell: (payload: { cellId: CellId }) => void, + addStagedCell: (payload: { cellId: CellId; edit: Edit }) => void, createNewCell: (opts: CreateNewCellAction) => void, ) { this.onCreateCell = onCreateCell; @@ -240,7 +244,7 @@ class CellCreationStream { before: true, }); if (cellId) { - this.addStagedCell({ cellId }); + this.addStagedCell({ cellId, edit: { type: "add_cell" } }); } this.hasMarimoImport = true; } diff --git a/frontend/src/core/ai/tools/__tests__/utils.test.ts b/frontend/src/core/ai/tools/__tests__/utils.test.ts new file mode 100644 index 00000000000..b194a022e40 --- /dev/null +++ b/frontend/src/core/ai/tools/__tests__/utils.test.ts @@ -0,0 +1,87 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { describe, expect, it } from "vitest"; +import type { ToolDescription } from "../base"; +import { formatToolDescription } from "../utils"; + +describe("formatToolDescription", () => { + it("formats a basic description only", () => { + const description: ToolDescription = { + baseDescription: "This is a simple tool", + }; + + expect(formatToolDescription(description)).toMatchInlineSnapshot( + `"This is a simple tool"`, + ); + }); + + it("formats description with whenToUse", () => { + const description: ToolDescription = { + baseDescription: "Edit notebook cells", + whenToUse: [ + "When user requests code changes", + "When refactoring is needed", + ], + }; + + expect(formatToolDescription(description)).toMatchInlineSnapshot(` + "Edit notebook cells + + ## When to use: + - When user requests code changes + - When refactoring is needed" + `); + }); + + it("formats description with all fields", () => { + const description: ToolDescription = { + baseDescription: "A comprehensive tool", + whenToUse: ["Use case 1", "Use case 2"], + avoidIf: ["Avoid case 1", "Avoid case 2"], + prerequisites: ["Prerequisite 1", "Prerequisite 2"], + sideEffects: ["Side effect 1", "Side effect 2"], + additionalInfo: "Some extra context and notes.", + }; + + expect(formatToolDescription(description)).toMatchInlineSnapshot(` + "A comprehensive tool + + ## When to use: + - Use case 1 + - Use case 2 + + ## Avoid if: + - Avoid case 1 + - Avoid case 2 + + ## Prerequisites: + - Prerequisite 1 + - Prerequisite 2 + + ## Side effects: + - Side effect 1 + - Side effect 2 + + ## Additional info: + - Some extra context and notes." + `); + }); + + it("formats description with only some fields", () => { + const description: ToolDescription = { + baseDescription: "A selective tool", + avoidIf: ["Don't use in production"], + sideEffects: ["Modifies state"], + }; + + expect(formatToolDescription(description)).toMatchInlineSnapshot(` + "A selective tool + + ## Avoid if: + - Don't use in production + + ## Side effects: + - Modifies state" + `); + }); +}); diff --git a/frontend/src/core/ai/tools/base.ts b/frontend/src/core/ai/tools/base.ts index e81b1b98026..845943e2efe 100644 --- a/frontend/src/core/ai/tools/base.ts +++ b/frontend/src/core/ai/tools/base.ts @@ -83,15 +83,18 @@ export const toolOutputBaseSchema = z.object({ meta: z.record(z.string(), z.unknown()).optional(), }); -/** - * Contract for a frontend tool. - * - * Implementations can be plain objects or classes. The registry consumes this - * interface without caring about the underlying implementation. - */ +export interface ToolDescription { + baseDescription: string; + whenToUse?: string[]; + avoidIf?: string[]; + prerequisites?: string[]; + sideEffects?: string[]; + additionalInfo?: string; +} + export interface AiTool { name: string; - description: string; + description: ToolDescription; schema: z.ZodType; outputSchema: z.ZodType; mode: CopilotMode[]; diff --git a/frontend/src/core/ai/tools/edit-notebook-tool.ts b/frontend/src/core/ai/tools/edit-notebook-tool.ts index 34d58407fe4..dad18ef7e23 100644 --- a/frontend/src/core/ai/tools/edit-notebook-tool.ts +++ b/frontend/src/core/ai/tools/edit-notebook-tool.ts @@ -1,32 +1,57 @@ /* Copyright 2024 Marimo. All rights reserved. */ +import type { EditorView } from "@codemirror/view"; import { z } from "zod"; -import { notebookAtom } from "@/core/cells/cells"; -import type { CellId } from "@/core/cells/ids"; +import { scrollAndHighlightCell } from "@/components/editor/links/cell-link"; +import { + createNotebookActions, + type CellPosition as NotebookCellPosition, + type NotebookState, + notebookAtom, + notebookReducer, +} from "@/core/cells/cells"; +import { CellId } from "@/core/cells/ids"; import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils"; import type { JotaiStore } from "@/core/state/jotai"; -import { stagedAICellsAtom } from "../staged-cells"; +import type { CellColumnId } from "@/utils/id-tree"; +import { + createStagedAICellsActions, + stagedAICellsAtom, + stagedAICellsReducer, +} from "../staged-cells"; import { type AiTool, + type ToolDescription, ToolExecutionError, type ToolOutputBase, toolOutputBaseSchema, } from "./base"; import type { CopilotMode } from "./registry"; -const description = ` -Perform editing operations on the current notebook. -Call this tool multiple times to perform multiple edits. +const description: ToolDescription = { + baseDescription: + "Perform editing operations on the current notebook. Call this tool multiple times to perform multiple edits.", + prerequisites: [ + "Find out the cellIds and columnIds first (call get_lightweight_cell_map tool)", + ], + additionalInfo: ` + Args: + edit (object): The editing operation to perform. Must be one of: + - update_cell: Update the code of an existing cell, pass CellId and the new code. + - add_cell: Add a new cell to the notebook. The position of the new cell is specified by the position argument. + Pass "__end__" to add the new cell at the end of the notebook. + Pass { cellId: cellId, before: true } to add the new cell before the specified cell. And before: false if after the specified cell. + Pass { type: "__end__", columnId: columnId } to add the new cell at the end of the specified column. + - delete_cell: Delete an existing cell, pass CellId. -Args: -- edit (object): The editing operation to perform. Must be one of: - - update_cell: Update the code of an existing cell. - - add_cell: Add a new cell to the notebook. - - delete_cell: Delete an existing cell. + Returns: + - A result object containing standard tool metadata.`, +}; -Returns: -- A result object containing standard tool metadata. -`; +type CellPosition = + | { cellId: CellId; before: boolean } + | { type: "__end__"; columnId: CellColumnId } + | "__end__"; const editNotebookSchema = z.object({ edit: z.discriminatedUnion("type", [ @@ -37,9 +62,18 @@ const editNotebookSchema = z.object({ }), z.object({ type: z.literal("add_cell"), - cellId: z.string() as unknown as z.ZodType, + position: z.union([ + z.object({ + cellId: z.string() as unknown as z.ZodType, + before: z.boolean(), + }), + z.object({ + type: z.literal("__end__"), + columnId: z.string() as unknown as z.ZodType, + }), + z.literal("__end__"), + ]) satisfies z.ZodType, code: z.string(), - language: z.enum(["python", "sql", "markdown"]).optional(), }), z.object({ type: z.literal("delete_cell"), @@ -56,6 +90,10 @@ export class EditNotebookTool implements AiTool { private readonly store: JotaiStore; + private readonly notebookActions: ReturnType; + private readonly stagedAICellsActions: ReturnType< + typeof createStagedAICellsActions + >; readonly name = "edit_notebook_tool"; readonly description = description; readonly schema = editNotebookSchema; @@ -64,6 +102,14 @@ export class EditNotebookTool constructor(store: JotaiStore) { this.store = store; + this.notebookActions = createNotebookActions((action) => { + this.store.set(notebookAtom, (state) => notebookReducer(state, action)); + }); + this.stagedAICellsActions = createStagedAICellsActions((action) => { + this.store.set(stagedAICellsAtom, (state) => + stagedAICellsReducer(state, action), + ); + }); } handler = async ({ edit }: EditNotebookInput): Promise => { @@ -72,49 +118,124 @@ export class EditNotebookTool const { cellId, code } = edit; const notebook = this.store.get(notebookAtom); - const cellIds = notebook.cellIds; - if (!cellIds.getColumns().some((column) => column.idSet.has(cellId))) { - throw new ToolExecutionError( - "Cell not found", - "CELL_NOT_FOUND", - false, - "Check which cells exist in the notebook", - ); - } + this.validateCellIdExists(cellId, notebook); + const editorView = this.getCellEditorView(cellId, notebook); + + scrollAndHighlightCell(cellId); + + const currentCellCode = editorView.state.doc.toString(); + this.stagedAICellsActions.addStagedCell({ + cellId, + edit: { type: "update_cell", previousCode: currentCellCode }, + }); + + updateEditorCodeFromPython(editorView, code); - const cellHandles = notebook.cellHandles; - const cellHandle = cellHandles[cellId].current; - if (!cellHandle?.editorView) { - throw new ToolExecutionError( - "Cell editor not found", - "CELL_EDITOR_NOT_FOUND", - false, - "Internal error, ask the user to report this error", - ); + break; + } + case "add_cell": { + const { position, code } = edit; + + // By default, add the new cell to the end of the notebook + let notebookPosition: NotebookCellPosition = "__end__"; + let before = false; + const newCellId = CellId.create(); + + if (typeof position === "object") { + const notebook = this.store.get(notebookAtom); + if ("cellId" in position) { + this.validateCellIdExists(position.cellId, notebook); + notebookPosition = position.cellId; + before = position.before; + } else if ("columnId" in position) { + this.validateColumnIdExists(position.columnId, notebook); + notebookPosition = { type: "__end__", columnId: position.columnId }; + } } - const currentCellCode = cellHandle.editorView.state.doc.toString(); + this.notebookActions.createNewCell({ + cellId: notebookPosition, + before, + code, + newCellId, + }); - const stagedAICells = this.store.get(stagedAICellsAtom); - const newStagedAICells = new Map([ - ...stagedAICells, - [cellId, { type: "update_cell", previousCode: currentCellCode }], - ]); - this.store.set(stagedAICellsAtom, newStagedAICells); + // Add to staged AICells + this.stagedAICellsActions.addStagedCell({ + cellId: newCellId, + edit: { type: "add_cell" }, + }); - updateEditorCodeFromPython(cellHandle.editorView, code); + // Scroll into view + scrollAndHighlightCell(newCellId); break; } - case "add_cell": - // TODO - break; - case "delete_cell": - // TODO + case "delete_cell": { + const { cellId } = edit; + const notebook = this.store.get(notebookAtom); + this.validateCellIdExists(cellId, notebook); + + const editorView = this.getCellEditorView(cellId, notebook); + const currentCellCode = editorView.state.doc.toString(); + + // Add to staged AICells + this.stagedAICellsActions.addStagedCell({ + cellId, + edit: { type: "delete_cell", previousCode: currentCellCode }, + }); + + this.notebookActions.deleteCell({ cellId }); break; + } } return { status: "success", + next_steps: ["If you need to perform more edits, call this tool again."], }; }; + + private validateCellIdExists(cellId: CellId, notebook: NotebookState) { + const cellIds = notebook.cellIds; + if (!cellIds.getColumns().some((column) => column.idSet.has(cellId))) { + throw new ToolExecutionError( + "Cell not found", + "CELL_NOT_FOUND", + false, + "Check which cells exist in the notebook", + ); + } + } + + private validateColumnIdExists( + columnId: CellColumnId, + notebook: NotebookState, + ) { + const cellIds = notebook.cellIds; + if (!cellIds.getColumns().some((column) => column.id === columnId)) { + throw new ToolExecutionError( + "Column not found", + "COLUMN_NOT_FOUND", + false, + "Check which columns exist in the notebook", + ); + } + } + + private getCellEditorView( + cellId: CellId, + notebook: NotebookState, + ): EditorView { + const cellHandles = notebook.cellHandles; + const cellHandle = cellHandles[cellId].current; + if (!cellHandle?.editorView) { + throw new ToolExecutionError( + "Cell editor not found", + "CELL_EDITOR_NOT_FOUND", + false, + "Internal error, ask the user to report this error", + ); + } + return cellHandle.editorView; + } } diff --git a/frontend/src/core/ai/tools/registry.ts b/frontend/src/core/ai/tools/registry.ts index 8897b6051fb..72b4b756d79 100644 --- a/frontend/src/core/ai/tools/registry.ts +++ b/frontend/src/core/ai/tools/registry.ts @@ -3,8 +3,10 @@ import type { components } from "@marimo-team/marimo-api"; import { Memoize } from "typescript-memoize"; import { type ZodObject, z } from "zod"; +import { store } from "@/core/state/jotai"; import { type AiTool, ToolExecutionError } from "./base"; -import { TestFrontendTool } from "./sample-tool"; +import { EditNotebookTool } from "./edit-notebook-tool"; +import { formatToolDescription } from "./utils"; export type AnyZodObject = ZodObject; @@ -114,7 +116,7 @@ export class FrontendToolRegistry { getToolSchemas(): FrontendToolDefinition[] { return [...this.tools.values()].map((tool) => ({ name: tool.name, - description: tool.description, + description: formatToolDescription(tool.description), parameters: z.toJSONSchema(tool.schema), source: "frontend", mode: tool.mode, @@ -123,6 +125,6 @@ export class FrontendToolRegistry { } export const FRONTEND_TOOL_REGISTRY = new FrontendToolRegistry([ - ...(import.meta.env.DEV ? [new TestFrontendTool()] : []), + new EditNotebookTool(store), // ADD MORE TOOLS HERE ]); diff --git a/frontend/src/core/ai/tools/sample-tool.ts b/frontend/src/core/ai/tools/sample-tool.ts index 8e1f93a52e4..b81fef49f8e 100644 --- a/frontend/src/core/ai/tools/sample-tool.ts +++ b/frontend/src/core/ai/tools/sample-tool.ts @@ -3,21 +3,23 @@ import { z } from "zod"; import { type AiTool, + type ToolDescription, ToolExecutionError, type ToolOutputBase, toolOutputBaseSchema, } from "./base"; import type { CopilotMode } from "./registry"; -const description = ` -Test frontend tool that returns a greeting message. +const description: ToolDescription = { + baseDescription: "Test frontend tool that returns a greeting message.", + additionalInfo: ` + Args: + - name (string): The name to include in the greeting. -Args: -- name (string): The name to include in the greeting. - -Returns: -- Output with data containing the greeting message. -`; + Returns: + - Output with data containing the greeting message. + `, +}; interface Input { name: string; diff --git a/frontend/src/core/ai/tools/utils.ts b/frontend/src/core/ai/tools/utils.ts new file mode 100644 index 00000000000..f8aca9fbeb7 --- /dev/null +++ b/frontend/src/core/ai/tools/utils.ts @@ -0,0 +1,23 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import type { ToolDescription } from "./base"; + +export function formatToolDescription(description: ToolDescription): string { + let result = description.baseDescription; + if (description.whenToUse) { + result += `\n\n## When to use:\n- ${description.whenToUse.join("\n- ")}`; + } + if (description.avoidIf) { + result += `\n\n## Avoid if:\n- ${description.avoidIf.join("\n- ")}`; + } + if (description.prerequisites) { + result += `\n\n## Prerequisites:\n- ${description.prerequisites.join("\n- ")}`; + } + if (description.sideEffects) { + result += `\n\n## Side effects:\n- ${description.sideEffects.join("\n- ")}`; + } + if (description.additionalInfo) { + result += `\n\n## Additional info:\n- ${description.additionalInfo}`; + } + return result; +} diff --git a/frontend/src/core/cells/cells.ts b/frontend/src/core/cells/cells.ts index 5fd8e220ba1..c7519fdc608 100644 --- a/frontend/src/core/cells/cells.ts +++ b/frontend/src/core/cells/cells.ts @@ -145,13 +145,18 @@ export function initialNotebookState(): NotebookState { }); } +/** The target cell ID to create a new cell relative to. Can be: + * - A CellId string for an existing cell + * - "__end__" to append at the end of the first column + * - {type: "__end__", columnId} to append at the end of a specific column + */ +export type CellPosition = + | CellId + | "__end__" + | { type: "__end__"; columnId: CellColumnId }; + export interface CreateNewCellAction { - /** The target cell ID to create a new cell relative to. Can be: - * - A CellId string for an existing cell - * - "__end__" to append at the end of the first column - * - {type: "__end__", columnId} to append at the end of a specific column - */ - cellId: CellId | "__end__" | { type: "__end__"; columnId: CellColumnId }; + cellId: CellPosition; /** Whether to insert before (true) or after (false) the target cell */ before: boolean; /** Initial code content for the new cell */ From afa15df01ae59fa5905ae3151aac830bbfe27345 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 16 Oct 2025 20:06:41 +0800 Subject: [PATCH 03/10] fix tests --- .../core/ai/__tests__/staged-cells.test.ts | 254 +++++++++++++----- 1 file changed, 190 insertions(+), 64 deletions(-) diff --git a/frontend/src/core/ai/__tests__/staged-cells.test.ts b/frontend/src/core/ai/__tests__/staged-cells.test.ts index 54a047c41a4..4a1b0ff5b67 100644 --- a/frontend/src/core/ai/__tests__/staged-cells.test.ts +++ b/frontend/src/core/ai/__tests__/staged-cells.test.ts @@ -35,6 +35,7 @@ vi.mock("../../cells/cells", () => ({ cellHandleAtom: vi.fn(() => ({ read: vi.fn(() => mockCellHandle), })), + getCellEditorView: vi.fn(() => mockCellHandle.current.editorViewOrNull), })); vi.mock("@/components/editor/cell/useDeleteCell", () => ({ @@ -66,32 +67,77 @@ describe("staged-cells", () => { vi.clearAllMocks(); // Reset the atom state - store.set(stagedAICellsAtom, new Set()); + store.set(stagedAICellsAtom, new Map()); }); describe("reducer and actions", () => { it("should initialize with empty map", () => { const state = initialState(); - expect(state).toEqual(new Set()); + expect(state).toEqual(new Map()); }); - it("should add cell IDs", () => { + it("should add cells with update_cell edit", () => { let state = initialState(); state = reducer(state, { type: "addStagedCell", - payload: { cellId: cellId1 }, + payload: { + cellId: cellId1, + edit: { type: "update_cell", previousCode: "old code 1" }, + }, }); state = reducer(state, { type: "addStagedCell", - payload: { cellId: cellId2 }, + payload: { + cellId: cellId2, + edit: { type: "update_cell", previousCode: "old code 2" }, + }, }); expect(state.has(cellId1)).toBe(true); expect(state.has(cellId2)).toBe(true); + expect(state.get(cellId1)).toEqual({ + type: "update_cell", + previousCode: "old code 1", + }); + expect(state.get(cellId2)).toEqual({ + type: "update_cell", + previousCode: "old code 2", + }); + }); + + it("should add cells with add_cell edit", () => { + let state = initialState(); + state = reducer(state, { + type: "addStagedCell", + payload: { cellId: cellId1, edit: { type: "add_cell" } }, + }); + + expect(state.has(cellId1)).toBe(true); + expect(state.get(cellId1)).toEqual({ type: "add_cell" }); + }); + + it("should add cells with delete_cell edit", () => { + let state = initialState(); + state = reducer(state, { + type: "addStagedCell", + payload: { + cellId: cellId1, + edit: { type: "delete_cell", previousCode: "deleted code" }, + }, + }); + + expect(state.has(cellId1)).toBe(true); + expect(state.get(cellId1)).toEqual({ + type: "delete_cell", + previousCode: "deleted code", + }); }); it("should remove cell IDs", () => { - const state = new Set([cellId1, cellId2]); + const state = new Map([ + [cellId1, { type: "add_cell" as const }], + [cellId2, { type: "add_cell" as const }], + ]); const newState = reducer(state, { type: "removeStagedCell", payload: cellId1, @@ -101,23 +147,26 @@ describe("staged-cells", () => { expect(newState.has(cellId2)).toBe(true); }); - it("should clear all cell IDs", () => { - const state = new Set([cellId1, cellId2]); + it("should clear all cells", () => { + const state = new Map([ + [cellId1, { type: "add_cell" as const }], + [cellId2, { type: "add_cell" as const }], + ]); const newState = reducer(state, { type: "clearStagedCells", payload: undefined, }); - expect(newState).toEqual(new Set()); + expect(newState).toEqual(new Map()); }); - it("should not mutate original state", () => { - const state = new Set([cellId1]); + it("should not mutate original state when adding", () => { + const state = new Map([[cellId1, { type: "add_cell" as const }]]); const originalSize = state.size; reducer(state, { type: "addStagedCell", - payload: { cellId: cellId2 }, + payload: { cellId: cellId2, edit: { type: "add_cell" } }, }); expect(state.size).toBe(originalSize); @@ -125,6 +174,23 @@ describe("staged-cells", () => { expect(state.has(cellId2)).toBe(false); }); + it("should not mutate original state when removing", () => { + const state = new Map([ + [cellId1, { type: "add_cell" as const }], + [cellId2, { type: "add_cell" as const }], + ]); + const originalSize = state.size; + + reducer(state, { + type: "removeStagedCell", + payload: cellId1, + }); + + expect(state.size).toBe(originalSize); + expect(state.has(cellId1)).toBe(true); + expect(state.has(cellId2)).toBe(true); + }); + it("should create action functions", () => { const mockDispatch = vi.fn(); const actions = createActions(mockDispatch); @@ -136,7 +202,7 @@ describe("staged-cells", () => { it("should initialize atom with empty map", () => { const state = store.get(stagedAICellsAtom); - expect(state).toEqual(new Set()); + expect(state).toEqual(new Map()); }); }); @@ -181,7 +247,13 @@ describe("staged-cells", () => { it("should delete all staged cells when cells exist", () => { // First set the atom state before rendering the hook - store.set(stagedAICellsAtom, new Set([cellId1, cellId2])); + store.set( + stagedAICellsAtom, + new Map([ + [cellId1, { type: "add_cell" }], + [cellId2, { type: "add_cell" }], + ]), + ); const { result } = renderHook(() => useStagedCells(store)); result.current.deleteAllStagedCells(); @@ -192,76 +264,130 @@ describe("staged-cells", () => { // Verify cells were cleared from the atom const state = store.get(stagedAICellsAtom); - expect(state).toEqual(new Set()); + expect(state).toEqual(new Map()); }); - }); - it("should add staged cell", () => { - const { result } = renderHook(() => useStagedCells(store)); + it("should add staged cell with edit info", () => { + const { result } = renderHook(() => useStagedCells(store)); - result.current.addStagedCell({ cellId: cellId1 }); + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "update_cell", previousCode: "old code" }, + }); - // Check that the cell was added to the atom - const state = store.get(stagedAICellsAtom); - expect(state.has(cellId1)).toBe(true); - }); + // Check that the cell was added to the atom with edit info + const state = store.get(stagedAICellsAtom); + expect(state.has(cellId1)).toBe(true); + expect(state.get(cellId1)).toEqual({ + type: "update_cell", + previousCode: "old code", + }); + }); - it("should remove staged cell", () => { - const { result } = renderHook(() => useStagedCells(store)); + it("should remove staged cell", () => { + const { result } = renderHook(() => useStagedCells(store)); - // First add cells - result.current.addStagedCell({ cellId: cellId1 }); - result.current.addStagedCell({ cellId: cellId2 }); + // First add cells + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "add_cell" }, + }); + result.current.addStagedCell({ + cellId: cellId2, + edit: { type: "add_cell" }, + }); - // Then remove one - result.current.removeStagedCell(cellId1); + // Then remove one + result.current.removeStagedCell(cellId1); - // Check that only the remaining cell is in the map - const state = store.get(stagedAICellsAtom); - expect(state.has(cellId1)).toBe(false); - expect(state.has(cellId2)).toBe(true); - }); + // Check that only the remaining cell is in the map + const state = store.get(stagedAICellsAtom); + expect(state.has(cellId1)).toBe(false); + expect(state.has(cellId2)).toBe(true); + }); - it("should clear all staged cells", () => { - const { result } = renderHook(() => useStagedCells(store)); + it("should clear all staged cells", () => { + const { result } = renderHook(() => useStagedCells(store)); - // First add some cells - result.current.addStagedCell({ cellId: cellId1 }); - result.current.addStagedCell({ cellId: cellId2 }); + // First add some cells + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "add_cell" }, + }); + result.current.addStagedCell({ + cellId: cellId2, + edit: { type: "add_cell" }, + }); - // Then clear all - result.current.clearStagedCells(); + // Then clear all + result.current.clearStagedCells(); - // Check that no cells remain - const state = store.get(stagedAICellsAtom); - expect(state).toEqual(new Set()); - }); + // Check that no cells remain + const state = store.get(stagedAICellsAtom); + expect(state).toEqual(new Map()); + }); - it("should handle multiple operations correctly", () => { - const { result } = renderHook(() => useStagedCells(store)); + it("should handle multiple operations correctly", () => { + const { result } = renderHook(() => useStagedCells(store)); - // Create a staged cell - const mockCellId = "mock-cell-id" as CellId; - vi.mocked(CellId.create).mockReturnValue(mockCellId); + // Create a staged cell + const mockCellId = "mock-cell-id" as CellId; + vi.mocked(CellId.create).mockReturnValue(mockCellId); - const createdCellId = result.current.createStagedCell("test code"); + const createdCellId = result.current.createStagedCell("test code"); - // Verify it was created and added - expect(createdCellId).toBe(mockCellId); - expect(mockCreateNewCell).toHaveBeenCalled(); + // Verify it was created and added + expect(createdCellId).toBe(mockCellId); + expect(mockCreateNewCell).toHaveBeenCalled(); - let state = store.get(stagedAICellsAtom); - expect(state.has(mockCellId)).toBe(true); + let state = store.get(stagedAICellsAtom); + expect(state.has(mockCellId)).toBe(true); + expect(state.get(mockCellId)).toEqual({ type: "add_cell" }); - // Delete the staged cell - result.current.deleteStagedCell(mockCellId); - expect(mockDeleteCellCallback).toHaveBeenCalledWith({ - cellId: mockCellId, + // Delete the staged cell + result.current.deleteStagedCell(mockCellId); + expect(mockDeleteCellCallback).toHaveBeenCalledWith({ + cellId: mockCellId, + }); + + // Verify it was removed from staged cells + state = store.get(stagedAICellsAtom); + expect(state.has(mockCellId)).toBe(false); }); - // Verify it was removed from staged cells - state = store.get(stagedAICellsAtom); - expect(state.has(mockCellId)).toBe(false); + it("should track edit history for updated cells", () => { + const { result } = renderHook(() => useStagedCells(store)); + + // Add a cell with update_cell edit type + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "update_cell", previousCode: "previous code" }, + }); + + const state = store.get(stagedAICellsAtom); + const edit = state.get(cellId1); + expect(edit).toEqual({ + type: "update_cell", + previousCode: "previous code", + }); + }); + + it("should track edit history for deleted cells", () => { + const { result } = renderHook(() => useStagedCells(store)); + + // Add a cell with delete_cell edit type + result.current.addStagedCell({ + cellId: cellId1, + edit: { type: "delete_cell", previousCode: "deleted content" }, + }); + + const state = store.get(stagedAICellsAtom); + const edit = state.get(cellId1); + expect(edit).toEqual({ + type: "delete_cell", + previousCode: "deleted content", + }); + }); }); }); From 6f0c44194d812c5f76178ae8d35b0f81b035a08b Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 16 Oct 2025 20:47:23 +0800 Subject: [PATCH 04/10] type fix --- frontend/src/core/ai/__tests__/staged-cells.test.ts | 13 ++++++------- frontend/src/core/ai/staged-cells.ts | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/frontend/src/core/ai/__tests__/staged-cells.test.ts b/frontend/src/core/ai/__tests__/staged-cells.test.ts index 4a1b0ff5b67..3a2f9328693 100644 --- a/frontend/src/core/ai/__tests__/staged-cells.test.ts +++ b/frontend/src/core/ai/__tests__/staged-cells.test.ts @@ -6,6 +6,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { CellId } from "@/core/cells/ids"; import { updateEditorCodeFromPython } from "../../codemirror/language/utils"; import { + type StagedAICells, stagedAICellsAtom, useStagedCells, visibleForTesting, @@ -247,13 +248,11 @@ describe("staged-cells", () => { it("should delete all staged cells when cells exist", () => { // First set the atom state before rendering the hook - store.set( - stagedAICellsAtom, - new Map([ - [cellId1, { type: "add_cell" }], - [cellId2, { type: "add_cell" }], - ]), - ); + const initialState: StagedAICells = new Map([ + [cellId1, { type: "add_cell" }], + [cellId2, { type: "add_cell" }], + ]); + store.set(stagedAICellsAtom, initialState); const { result } = renderHook(() => useStagedCells(store)); result.current.deleteAllStagedCells(); diff --git a/frontend/src/core/ai/staged-cells.ts b/frontend/src/core/ai/staged-cells.ts index 5a2dd656c61..cafb79d9dc7 100644 --- a/frontend/src/core/ai/staged-cells.ts +++ b/frontend/src/core/ai/staged-cells.ts @@ -33,7 +33,7 @@ type Edit = | { type: Extract } | { type: Extract; previousCode: string }; -type StagedAICells = Map; +export type StagedAICells = Map; const initialState = (): StagedAICells => { return new Map(); From 1a2cb0c401f3245d3c1bf68e1fb5abcd58792dfd Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 16 Oct 2025 21:19:42 +0800 Subject: [PATCH 05/10] improve prompt --- .../src/core/ai/tools/edit-notebook-tool.ts | 4 +++ marimo/_server/ai/prompts.py | 26 ++++++++++++++----- .../ai/snapshots/chat_system_prompts.txt | 9 ++++++- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/frontend/src/core/ai/tools/edit-notebook-tool.ts b/frontend/src/core/ai/tools/edit-notebook-tool.ts index dad18ef7e23..42d5fcf59fc 100644 --- a/frontend/src/core/ai/tools/edit-notebook-tool.ts +++ b/frontend/src/core/ai/tools/edit-notebook-tool.ts @@ -44,6 +44,10 @@ const description: ToolDescription = { Pass { type: "__end__", columnId: columnId } to add the new cell at the end of the specified column. - delete_cell: Delete an existing cell, pass CellId. + For adding code, use the following guidelines: + - Markdown cells: use mo.md(f"""{content}""") function to insert content. + - SQL cells: use mo.sql(f"""{content}""") function to insert content. If a database engine is specified, use mo.sql(f"""{content}""", engine=engine) instead. + Returns: - A result object containing standard tool metadata.`, }; diff --git a/marimo/_server/ai/prompts.py b/marimo/_server/ai/prompts.py index df5dc278972..67d6df137e8 100644 --- a/marimo/_server/ai/prompts.py +++ b/marimo/_server/ai/prompts.py @@ -389,13 +389,27 @@ def get_chat_system_prompt( chart """ - for language in language_rules: - if len(language_rules[language]) == 0: - continue + if mode == "agent": + # In agent-mode, we can add how to insert cells into the notebook. + for lang in LANGUAGES: + # check if multiple_cells rules are present for this language + rule_to_add = language_rules_multiple_cells.get( + lang, language_rules.get(lang, []) + ) + if rule_to_add: + system_prompt += ( + f"\n\n## Rules for {lang}:\n{_rules(rule_to_add)}" + ) - system_prompt += ( - f"\n\n## Rules for {language}:\n{_rules(language_rules[language])}" - ) + system_prompt += "\n\n## Rules for inserting cells:\n" + system_prompt += 'For markdown cells, use `mo.md(f"""{content}""")`\n' + system_prompt += 'For sql cells, use `mo.sql(f"""{content}""")`. If a database engine is specified, use `mo.sql(f"""{content}""", engine=engine)` instead.\n' + else: + for language in language_rules: + if len(language_rules[language]) == 0: + continue + + system_prompt += f"\n\n## Rules for {language}:\n{_rules(language_rules[language])}" if custom_rules and custom_rules.strip(): system_prompt += f"\n\n## Additional rules:\n{custom_rules}" diff --git a/tests/_server/ai/snapshots/chat_system_prompts.txt b/tests/_server/ai/snapshots/chat_system_prompts.txt index 076475eec3a..99c3d312a9c 100644 --- a/tests/_server/ai/snapshots/chat_system_prompts.txt +++ b/tests/_server/ai/snapshots/chat_system_prompts.txt @@ -1027,7 +1027,14 @@ chart 7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning. ## Rules for sql: -1. The SQL must use duckdb syntax. +1. SQL cells start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines. +2. This will automatically display the result in the UI. You do not need to return the dataframe in the cell. +3. The SQL must use the syntax of the database engine specified in the `engine` variable. If no engine, then use duckdb syntax. + +## Rules for inserting cells: +For markdown cells, use `mo.md(f"""{content}""")` +For sql cells, use `mo.sql(f"""{content}""")`. If a database engine is specified, use `mo.sql(f"""{content}""", engine=engine)` instead. + ==================== kitchen sink ==================== From 8efb87627ef5466d2b5b213642e95840b27c163e Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 17 Oct 2025 01:56:45 +0800 Subject: [PATCH 06/10] add tests --- frontend/src/components/chat/chat-panel.tsx | 2 +- .../__tests__/edit-notebook-tool.test.ts | 551 ++++++++++++++++++ 2 files changed, 552 insertions(+), 1 deletion(-) create mode 100644 frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts diff --git a/frontend/src/components/chat/chat-panel.tsx b/frontend/src/components/chat/chat-panel.tsx index 79d5f1974ec..e92fc193def 100644 --- a/frontend/src/components/chat/chat-panel.tsx +++ b/frontend/src/components/chat/chat-panel.tsx @@ -322,7 +322,7 @@ const ChatInputFooter: React.FC = memo( }, { value: "agent", - label: "Agent", + label: "Agent (beta)", subtitle: "Use AI with access to read and write tools", }, ]; diff --git a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts new file mode 100644 index 00000000000..47cb69595e4 --- /dev/null +++ b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts @@ -0,0 +1,551 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { EditorState } from "@codemirror/state"; +import { EditorView } from "@codemirror/view"; +import { getDefaultStore } from "jotai"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { MockNotebook } from "@/__mocks__/notebook"; +import { notebookAtom } from "@/core/cells/cells"; +import type { CellId } from "@/core/cells/ids"; +import { OverridingHotkeyProvider } from "@/core/hotkeys/hotkeys"; +import type { CellColumnId } from "@/utils/id-tree"; +import { MultiColumn } from "@/utils/id-tree"; +import { cellConfigExtension } from "../../../codemirror/config/extension"; +import { adaptiveLanguageConfiguration } from "../../../codemirror/language/extension"; +import { stagedAICellsAtom } from "../../staged-cells"; +import { ToolExecutionError } from "../base"; +import { EditNotebookTool } from "../edit-notebook-tool"; + +// Mock scrollAndHighlightCell +vi.mock("@/components/editor/links/cell-link", () => ({ + scrollAndHighlightCell: vi.fn(), +})); + +// Mock updateEditorCodeFromPython +vi.mock("@/core/codemirror/language/utils", () => ({ + updateEditorCodeFromPython: vi.fn(), +})); + +import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils"; + +function createMockEditorView(code: string): EditorView { + return new EditorView({ + state: EditorState.create({ + doc: code, + extensions: [ + adaptiveLanguageConfiguration({ + cellId: "cell1" as CellId, + completionConfig: { + copilot: false, + activate_on_typing: true, + codeium_api_key: null, + }, + hotkeys: new OverridingHotkeyProvider({}), + placeholderType: "marimo-import", + lspConfig: {}, + }), + cellConfigExtension({ + completionConfig: { + copilot: false, + activate_on_typing: true, + codeium_api_key: null, + }, + hotkeys: new OverridingHotkeyProvider({}), + placeholderType: "marimo-import", + lspConfig: {}, + diagnosticsConfig: {}, + }), + ], + }), + }); +} + +describe("EditNotebookTool", () => { + let store: ReturnType; + let tool: EditNotebookTool; + let cellId1: CellId; + let cellId2: CellId; + let cellId3: CellId; + + beforeEach(() => { + store = getDefaultStore(); + tool = new EditNotebookTool(store); + + cellId1 = "cell-1" as CellId; + cellId2 = "cell-2" as CellId; + cellId3 = "cell-3" as CellId; + + // Reset mocks + vi.clearAllMocks(); + + // Reset atom states + store.set(stagedAICellsAtom, new Map()); + }); + + describe("tool metadata", () => { + it("should have correct metadata", () => { + expect(tool.name).toBe("edit_notebook_tool"); + expect(tool.description).toBeDefined(); + expect(tool.description.baseDescription).toContain("editing operations"); + expect(tool.schema).toBeDefined(); + expect(tool.outputSchema).toBeDefined(); + expect(tool.mode).toEqual(["agent"]); + }); + }); + + describe("update_cell operation", () => { + it("should update cell with new code", async () => { + const oldCode = "x = 1"; + const newCode = "x = 2"; + + // Create notebook state with mock editor view + const editorView = createMockEditorView(oldCode); + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: oldCode }, + }, + }); + notebook.cellHandles[cellId1] = { current: { editorView } } as never; + store.set(notebookAtom, notebook); + + const result = await tool.handler({ + edit: { + type: "update_cell", + cellId: cellId1, + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + expect(vi.mocked(updateEditorCodeFromPython)).toHaveBeenCalledWith( + editorView, + newCode, + ); + + // Check that cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.has(cellId1)).toBe(true); + expect(stagedCells.get(cellId1)).toEqual({ + type: "update_cell", + previousCode: oldCode, + }); + }); + + it("should throw error when cell ID doesn't exist", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: "nonexistent" as CellId, + code: "x = 2", + }, + }), + ).rejects.toThrow(ToolExecutionError); + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: "nonexistent" as CellId, + code: "x = 2", + }, + }), + ).rejects.toThrow("Cell not found"); + }); + + it("should throw error when cell editor not found", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + // Don't set editorView + notebook.cellHandles[cellId1] = { current: null } as never; + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: cellId1, + code: "x = 2", + }, + }), + ).rejects.toThrow("Cell editor not found"); + }); + }); + + describe("add_cell operation", () => { + it("should add cell at the end of the notebook", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + const newCode = "y = 2"; + const result = await tool.handler({ + edit: { + type: "add_cell", + position: "__end__", + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + + // Check that a new cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.size).toBe(1); + const [cellId, edit] = [...stagedCells.entries()][0]; + expect(edit).toEqual({ type: "add_cell" }); + expect(cellId).toBeDefined(); + }); + + it("should add cell before a specific cell", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + [cellId2]: { code: "x = 2" }, + }, + }); + store.set(notebookAtom, notebook); + + const newCode = "y = 2"; + const result = await tool.handler({ + edit: { + type: "add_cell", + position: { cellId: cellId2, before: true }, + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + + // Check that a new cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.size).toBe(1); + const [_cellId, edit] = [...stagedCells.entries()][0]; + expect(edit).toEqual({ type: "add_cell" }); + }); + + it("should add cell after a specific cell", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + [cellId2]: { code: "x = 2" }, + }, + }); + store.set(notebookAtom, notebook); + + const newCode = "y = 2"; + const result = await tool.handler({ + edit: { + type: "add_cell", + position: { cellId: cellId2, before: false }, + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + + // Check that a new cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.size).toBe(1); + }); + + it("should add cell at the end of a specific column", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + [cellId2]: { code: "x = 2" }, + }, + }); + // Create multi-column layout + notebook.cellIds = MultiColumn.from([[cellId1], [cellId2]]); + const columnId = notebook.cellIds.getColumns()[1].id; + store.set(notebookAtom, notebook); + + const newCode = "y = 2"; + const result = await tool.handler({ + edit: { + type: "add_cell", + position: { type: "__end__", columnId }, + code: newCode, + }, + }); + + expect(result.status).toBe("success"); + + // Check that a new cell was staged + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.size).toBe(1); + }); + + it("should throw error when cell ID doesn't exist for position", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "add_cell", + position: { cellId: "nonexistent" as CellId, before: true }, + code: "y = 2", + }, + }), + ).rejects.toThrow("Cell not found"); + }); + + it("should throw error when column ID doesn't exist", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "add_cell", + position: { + type: "__end__", + columnId: "nonexistent" as CellColumnId, + }, + code: "y = 2", + }, + }), + ).rejects.toThrow("Column not found"); + }); + }); + + describe("delete_cell operation", () => { + it("should delete a cell", async () => { + const cellCode = "x = 1"; + const editorView = createMockEditorView(cellCode); + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: cellCode }, + [cellId2]: { code: "x = 2" }, + }, + }); + notebook.cellHandles[cellId1] = { current: { editorView } } as never; + store.set(notebookAtom, notebook); + + const result = await tool.handler({ + edit: { + type: "delete_cell", + cellId: cellId1, + }, + }); + + expect(result.status).toBe("success"); + + // Check that cell was staged for deletion + const stagedCells = store.get(stagedAICellsAtom); + expect(stagedCells.has(cellId1)).toBe(true); + expect(stagedCells.get(cellId1)).toEqual({ + type: "delete_cell", + previousCode: cellCode, + }); + }); + + it("should throw error when cell ID doesn't exist", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "delete_cell", + cellId: "nonexistent" as CellId, + }, + }), + ).rejects.toThrow("Cell not found"); + }); + + it("should throw error when cell editor not found", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + // Don't set editorView + notebook.cellHandles[cellId1] = { current: null } as never; + store.set(notebookAtom, notebook); + + await expect( + tool.handler({ + edit: { + type: "delete_cell", + cellId: cellId1, + }, + }), + ).rejects.toThrow("Cell editor not found"); + }); + }); + + describe("validation", () => { + it("should validate cell exists in multi-column notebook", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + [cellId2]: { code: "x = 2" }, + [cellId3]: { code: "x = 3" }, + }, + }); + // Create multi-column layout + notebook.cellIds = MultiColumn.from([[cellId1, cellId2], [cellId3]]); + store.set(notebookAtom, notebook); + + // Should not throw for cells in different columns + const editorView = createMockEditorView("x = 1"); + notebook.cellHandles[cellId1] = { current: { editorView } } as never; + notebook.cellHandles[cellId3] = { current: { editorView } } as never; + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: cellId1, + code: "y = 1", + }, + }), + ).resolves.toBeDefined(); + + await expect( + tool.handler({ + edit: { + type: "update_cell", + cellId: cellId3, + code: "y = 3", + }, + }), + ).resolves.toBeDefined(); + }); + }); + + describe("return value", () => { + it("should return success status with next steps", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = 1" }, + }, + }); + store.set(notebookAtom, notebook); + + const result = await tool.handler({ + edit: { + type: "add_cell", + position: "__end__", + code: "y = 2", + }, + }); + + expect(result.status).toBe("success"); + expect(result.next_steps).toBeDefined(); + expect(Array.isArray(result.next_steps)).toBe(true); + expect(result.next_steps?.length).toBeGreaterThan(0); + }); + }); + + describe("schema validation", () => { + it("should validate update_cell input", () => { + const validInput = { + edit: { + type: "update_cell", + cellId: cellId1, + code: "x = 2", + }, + }; + + const result = tool.schema.safeParse(validInput); + expect(result.success).toBe(true); + }); + + it("should validate add_cell input with __end__ position", () => { + const validInput = { + edit: { + type: "add_cell", + position: "__end__", + code: "x = 2", + }, + }; + + const result = tool.schema.safeParse(validInput); + expect(result.success).toBe(true); + }); + + it("should validate add_cell input with cellId position", () => { + const validInput = { + edit: { + type: "add_cell", + position: { + cellId: cellId1, + before: true, + }, + code: "x = 2", + }, + }; + + const result = tool.schema.safeParse(validInput); + expect(result.success).toBe(true); + }); + + it("should validate add_cell input with columnId position", () => { + const validInput = { + edit: { + type: "add_cell", + position: { + type: "__end__", + columnId: "column-1" as CellColumnId, + }, + code: "x = 2", + }, + }; + + const result = tool.schema.safeParse(validInput); + expect(result.success).toBe(true); + }); + + it("should validate delete_cell input", () => { + const validInput = { + edit: { + type: "delete_cell", + cellId: cellId1, + }, + }; + + const result = tool.schema.safeParse(validInput); + expect(result.success).toBe(true); + }); + + it("should reject invalid input", () => { + const invalidInput = { + edit: { + type: "invalid_type", + }, + }; + + const result = tool.schema.safeParse(invalidInput); + expect(result.success).toBe(false); + }); + }); +}); From 6c2ec437ec0041d8316306bfb18f0d125bafd823 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 17 Oct 2025 02:02:20 +0800 Subject: [PATCH 07/10] remove some tests --- .../__tests__/edit-notebook-tool.test.ts | 83 ------------------- 1 file changed, 83 deletions(-) diff --git a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts index 47cb69595e4..333df23e208 100644 --- a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts +++ b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts @@ -465,87 +465,4 @@ describe("EditNotebookTool", () => { expect(result.next_steps?.length).toBeGreaterThan(0); }); }); - - describe("schema validation", () => { - it("should validate update_cell input", () => { - const validInput = { - edit: { - type: "update_cell", - cellId: cellId1, - code: "x = 2", - }, - }; - - const result = tool.schema.safeParse(validInput); - expect(result.success).toBe(true); - }); - - it("should validate add_cell input with __end__ position", () => { - const validInput = { - edit: { - type: "add_cell", - position: "__end__", - code: "x = 2", - }, - }; - - const result = tool.schema.safeParse(validInput); - expect(result.success).toBe(true); - }); - - it("should validate add_cell input with cellId position", () => { - const validInput = { - edit: { - type: "add_cell", - position: { - cellId: cellId1, - before: true, - }, - code: "x = 2", - }, - }; - - const result = tool.schema.safeParse(validInput); - expect(result.success).toBe(true); - }); - - it("should validate add_cell input with columnId position", () => { - const validInput = { - edit: { - type: "add_cell", - position: { - type: "__end__", - columnId: "column-1" as CellColumnId, - }, - code: "x = 2", - }, - }; - - const result = tool.schema.safeParse(validInput); - expect(result.success).toBe(true); - }); - - it("should validate delete_cell input", () => { - const validInput = { - edit: { - type: "delete_cell", - cellId: cellId1, - }, - }; - - const result = tool.schema.safeParse(validInput); - expect(result.success).toBe(true); - }); - - it("should reject invalid input", () => { - const invalidInput = { - edit: { - type: "invalid_type", - }, - }; - - const result = tool.schema.safeParse(invalidInput); - expect(result.success).toBe(false); - }); - }); }); From a82733131fcd559712a3d47a3aa161ff21428379 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 17 Oct 2025 02:22:09 +0800 Subject: [PATCH 08/10] remove deleting cells for now --- .../components/editor/cell/StagedAICell.tsx | 1 + .../__tests__/edit-notebook-tool.test.ts | 6 ++-- .../src/core/ai/tools/edit-notebook-tool.ts | 33 +++++++++++-------- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/frontend/src/components/editor/cell/StagedAICell.tsx b/frontend/src/components/editor/cell/StagedAICell.tsx index 070879da12a..dbc0616aa58 100644 --- a/frontend/src/components/editor/cell/StagedAICell.tsx +++ b/frontend/src/components/editor/cell/StagedAICell.tsx @@ -59,6 +59,7 @@ export const StagedAICellFooter: React.FC<{ cellId: CellId }> = ({ break; case "delete_cell": // TODO: Revert delete + removeStagedCell(cellId); break; } }; diff --git a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts index 333df23e208..ba9e1898571 100644 --- a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts +++ b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts @@ -333,7 +333,7 @@ describe("EditNotebookTool", () => { }); describe("delete_cell operation", () => { - it("should delete a cell", async () => { + it.fails("should delete a cell", async () => { const cellCode = "x = 1"; const editorView = createMockEditorView(cellCode); const notebook = MockNotebook.notebookState({ @@ -363,7 +363,7 @@ describe("EditNotebookTool", () => { }); }); - it("should throw error when cell ID doesn't exist", async () => { + it.fails("should throw error when cell ID doesn't exist", async () => { const notebook = MockNotebook.notebookState({ cellData: { [cellId1]: { code: "x = 1" }, @@ -381,7 +381,7 @@ describe("EditNotebookTool", () => { ).rejects.toThrow("Cell not found"); }); - it("should throw error when cell editor not found", async () => { + it.fails("should throw error when cell editor not found", async () => { const notebook = MockNotebook.notebookState({ cellData: { [cellId1]: { code: "x = 1" }, diff --git a/frontend/src/core/ai/tools/edit-notebook-tool.ts b/frontend/src/core/ai/tools/edit-notebook-tool.ts index 42d5fcf59fc..07bbfc6bafa 100644 --- a/frontend/src/core/ai/tools/edit-notebook-tool.ts +++ b/frontend/src/core/ai/tools/edit-notebook-tool.ts @@ -175,23 +175,28 @@ export class EditNotebookTool break; } - case "delete_cell": { - const { cellId } = edit; - const notebook = this.store.get(notebookAtom); - this.validateCellIdExists(cellId, notebook); + case "delete_cell": + return { + status: "error", + message: "Deleting cells are not supported yet", + }; - const editorView = this.getCellEditorView(cellId, notebook); - const currentCellCode = editorView.state.doc.toString(); + // const { cellId } = edit; - // Add to staged AICells - this.stagedAICellsActions.addStagedCell({ - cellId, - edit: { type: "delete_cell", previousCode: currentCellCode }, - }); + // const notebook = this.store.get(notebookAtom); + // this.validateCellIdExists(cellId, notebook); - this.notebookActions.deleteCell({ cellId }); - break; - } + // const editorView = this.getCellEditorView(cellId, notebook); + // const currentCellCode = editorView.state.doc.toString(); + + // // Add to staged AICells + // this.stagedAICellsActions.addStagedCell({ + // cellId, + // edit: { type: "delete_cell", previousCode: currentCellCode }, + // }); + + // this.notebookActions.deleteCell({ cellId }); + // break; } return { status: "success", From 74d154c05c00e8d12b79859e33d5a71a4983d9dc Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 17 Oct 2025 22:14:07 +0800 Subject: [PATCH 09/10] add floating banner for AI completions & support deletes (#6825) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## 📝 Summary https://github.com/user-attachments/assets/05af6efe-65c2-4312-9571-f0cb88c378a6 ## 🔍 Description of Changes ## 📋 Checklist - [x] I have read the [contributor guidelines](https://github.com/marimo-team/marimo/blob/main/CONTRIBUTING.md). - [x] For large changes, or changes that affect the public API: this change was discussed or approved through an issue, on [Discord](https://marimo.io/discord?ref=pr), or the community [discussions](https://github.com/marimo-team/marimo/discussions) (Please provide a link if applicable). - [x] I have added tests for the changes made. - [x] I have run the code and verified that it works as expected. --- .../editor/ai/completion-handlers.tsx | 17 ++- .../components/editor/cell/StagedAICell.tsx | 104 ++++++++++++------ .../editor/chrome/wrapper/app-chrome.tsx | 2 + .../editor/chrome/wrapper/minimap.tsx | 2 + .../chrome/wrapper/pending-ai-cells.tsx | 100 +++++++++++++++++ frontend/src/core/ai/staged-cells.ts | 2 +- .../__tests__/edit-notebook-tool.test.ts | 6 +- .../src/core/ai/tools/edit-notebook-tool.ts | 36 +++--- frontend/src/css/app/Cell.css | 31 ++++++ frontend/src/utils/__tests__/arrays.test.ts | 42 +++++++ frontend/src/utils/arrays.ts | 25 +++++ 11 files changed, 307 insertions(+), 60 deletions(-) create mode 100644 frontend/src/components/editor/chrome/wrapper/pending-ai-cells.tsx diff --git a/frontend/src/components/editor/ai/completion-handlers.tsx b/frontend/src/components/editor/ai/completion-handlers.tsx index 7b061b7a055..f4bd85fa638 100644 --- a/frontend/src/components/editor/ai/completion-handlers.tsx +++ b/frontend/src/components/editor/ai/completion-handlers.tsx @@ -114,6 +114,7 @@ export const CompletionActionsCellFooter: React.FC<{ export const AcceptCompletionButton: React.FC<{ isLoading: boolean; onAccept: () => void; + text?: string; size?: "xs" | "sm"; buttonStyles?: string; playButtonStyles?: string; @@ -122,6 +123,7 @@ export const AcceptCompletionButton: React.FC<{ }> = ({ isLoading, onAccept, + text = "Accept", size = "sm", buttonStyles, acceptShortcut, @@ -150,7 +152,7 @@ export const AcceptCompletionButton: React.FC<{ onClick={onAccept} className={`${baseClasses} rounded-r-none ${buttonStyles}`} > - Accept + {text} {acceptShortcut && ( )} @@ -178,7 +180,7 @@ export const AcceptCompletionButton: React.FC<{ onClick={onAccept} className={`${baseClasses} rounded px-3 ${buttonStyles}`} > - Accept + {text} {acceptShortcut && ( )} @@ -188,10 +190,17 @@ export const AcceptCompletionButton: React.FC<{ export const RejectCompletionButton: React.FC<{ onDecline: () => void; + text?: string; size?: "xs" | "sm"; className?: string; declineShortcut?: string; -}> = ({ onDecline, size = "sm", className, declineShortcut }) => { +}> = ({ + onDecline, + text = "Reject", + size = "sm", + className, + declineShortcut, +}) => { return ( + + {currentIndex === null + ? `${listStagedCells.length} pending` + : `${currentIndex + 1} / ${listStagedCells.length}`} + + + + +
+ +
+ + +
+
+ ); +}; diff --git a/frontend/src/core/ai/staged-cells.ts b/frontend/src/core/ai/staged-cells.ts index cafb79d9dc7..f25ac18a654 100644 --- a/frontend/src/core/ai/staged-cells.ts +++ b/frontend/src/core/ai/staged-cells.ts @@ -28,7 +28,7 @@ import type { EditType } from "./tools/edit-notebook-tool"; * And we only track one set of staged cells at a time. */ -type Edit = +export type Edit = | { type: Extract; previousCode: string } | { type: Extract } | { type: Extract; previousCode: string }; diff --git a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts index ba9e1898571..333df23e208 100644 --- a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts +++ b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts @@ -333,7 +333,7 @@ describe("EditNotebookTool", () => { }); describe("delete_cell operation", () => { - it.fails("should delete a cell", async () => { + it("should delete a cell", async () => { const cellCode = "x = 1"; const editorView = createMockEditorView(cellCode); const notebook = MockNotebook.notebookState({ @@ -363,7 +363,7 @@ describe("EditNotebookTool", () => { }); }); - it.fails("should throw error when cell ID doesn't exist", async () => { + it("should throw error when cell ID doesn't exist", async () => { const notebook = MockNotebook.notebookState({ cellData: { [cellId1]: { code: "x = 1" }, @@ -381,7 +381,7 @@ describe("EditNotebookTool", () => { ).rejects.toThrow("Cell not found"); }); - it.fails("should throw error when cell editor not found", async () => { + it("should throw error when cell editor not found", async () => { const notebook = MockNotebook.notebookState({ cellData: { [cellId1]: { code: "x = 1" }, diff --git a/frontend/src/core/ai/tools/edit-notebook-tool.ts b/frontend/src/core/ai/tools/edit-notebook-tool.ts index 07bbfc6bafa..f9739e15d5f 100644 --- a/frontend/src/core/ai/tools/edit-notebook-tool.ts +++ b/frontend/src/core/ai/tools/edit-notebook-tool.ts @@ -32,7 +32,7 @@ const description: ToolDescription = { baseDescription: "Perform editing operations on the current notebook. Call this tool multiple times to perform multiple edits.", prerequisites: [ - "Find out the cellIds and columnIds first (call get_lightweight_cell_map tool)", + "Find out the cellIds and columnIds first (call lightweight cell map tool)", ], additionalInfo: ` Args: @@ -42,7 +42,7 @@ const description: ToolDescription = { Pass "__end__" to add the new cell at the end of the notebook. Pass { cellId: cellId, before: true } to add the new cell before the specified cell. And before: false if after the specified cell. Pass { type: "__end__", columnId: columnId } to add the new cell at the end of the specified column. - - delete_cell: Delete an existing cell, pass CellId. + - delete_cell: Delete an existing cell, pass CellId. For deleting cells, the user needs to accept the deletion to actually delete the cell, so you may still see the cell in the notebook on subsequent edits which is fine. For adding code, use the following guidelines: - Markdown cells: use mo.md(f"""{content}""") function to insert content. @@ -175,28 +175,24 @@ export class EditNotebookTool break; } - case "delete_cell": - return { - status: "error", - message: "Deleting cells are not supported yet", - }; + case "delete_cell": { + const { cellId } = edit; - // const { cellId } = edit; - - // const notebook = this.store.get(notebookAtom); - // this.validateCellIdExists(cellId, notebook); + const notebook = this.store.get(notebookAtom); + this.validateCellIdExists(cellId, notebook); - // const editorView = this.getCellEditorView(cellId, notebook); - // const currentCellCode = editorView.state.doc.toString(); + const editorView = this.getCellEditorView(cellId, notebook); + const currentCellCode = editorView.state.doc.toString(); - // // Add to staged AICells - // this.stagedAICellsActions.addStagedCell({ - // cellId, - // edit: { type: "delete_cell", previousCode: currentCellCode }, - // }); + // Add to staged AICells - don't actually delete the cell yet + this.stagedAICellsActions.addStagedCell({ + cellId, + edit: { type: "delete_cell", previousCode: currentCellCode }, + }); - // this.notebookActions.deleteCell({ cellId }); - // break; + scrollAndHighlightCell(cellId); + break; + } } return { status: "success", diff --git a/frontend/src/css/app/Cell.css b/frontend/src/css/app/Cell.css index e75e965f154..8a090ef3ef4 100644 --- a/frontend/src/css/app/Cell.css +++ b/frontend/src/css/app/Cell.css @@ -34,6 +34,37 @@ z-index: 1; } + /* Styling for cells marked for deletion by AI */ + &:has(.mo-ai-deleted-cell) { + opacity: 0.7; + position: relative; + + .output-area, + .cm-gutters, + .cm { + background-color: var(--red-2); + } + + /* Add a red glow to indicate deletion */ + &::before { + content: ""; + position: absolute; + inset: 0; + border-radius: 10px; + pointer-events: none; + opacity: 0.6; + box-shadow: 0px 0px 10px 0px var(--red-9); + z-index: 1; + } + + /* Add strikethrough effect to the code */ + .cm-content { + text-decoration: line-through; + text-decoration-color: var(--red-9); + text-decoration-thickness: 1px; + } + } + &:focus-within { z-index: 20; } diff --git a/frontend/src/utils/__tests__/arrays.test.ts b/frontend/src/utils/__tests__/arrays.test.ts index 47da4be94ae..d5108086492 100644 --- a/frontend/src/utils/__tests__/arrays.test.ts +++ b/frontend/src/utils/__tests__/arrays.test.ts @@ -8,6 +8,7 @@ import { arrayMove, arrayShallowEquals, arrayToggle, + getNextIndex, } from "../arrays"; describe("arrays", () => { @@ -128,3 +129,44 @@ describe("arrays", () => { }); }); }); + +describe("getNextIndex", () => { + it("should return 0 if listLength is 0", () => { + expect(getNextIndex(null, 0, "up")).toBe(0); + expect(getNextIndex(null, 0, "down")).toBe(0); + expect(getNextIndex(0, 0, "up")).toBe(0); + expect(getNextIndex(0, 0, "down")).toBe(0); + }); + + it("should return 0 if currentIndex is null and direction is up", () => { + expect(getNextIndex(null, 10, "up")).toBe(0); + }); + + it("should return last index if currentIndex is null and direction is down", () => { + expect(getNextIndex(null, 10, "down")).toBe(9); + }); + + it("should return the next index in the list", () => { + expect(getNextIndex(0, 10, "up")).toBe(1); + }); + + it("should wrap around to the start of the list", () => { + expect(getNextIndex(9, 10, "up")).toBe(0); + }); + + it("should return next index in middle of list", () => { + expect(getNextIndex(5, 10, "up")).toBe(6); + }); + + it("should return previous index in middle of list", () => { + expect(getNextIndex(5, 10, "down")).toBe(4); + }); + + it("should wrap around to the end of the list", () => { + expect(getNextIndex(0, 10, "down")).toBe(9); + }); + + it("should return the previous index in the list", () => { + expect(getNextIndex(1, 10, "down")).toBe(0); + }); +}); diff --git a/frontend/src/utils/arrays.ts b/frontend/src/utils/arrays.ts index b3107806b02..f1fa2015380 100644 --- a/frontend/src/utils/arrays.ts +++ b/frontend/src/utils/arrays.ts @@ -86,3 +86,28 @@ export function uniqueBy(arr: T[], key: (item: T) => string): T[] { } return result; } + +/** + * Get the next index in the list, wrapping around to the start or end if necessary. + * @param currentIndex - The current index, or null if there is no current index. + * @param listLength - The length of the list. + * @param direction - The direction to move in. + * @returns The next index. + */ +export function getNextIndex( + currentIndex: number | null, + listLength: number, + direction: "up" | "down", +): number { + if (listLength === 0) { + return 0; + } + + if (currentIndex === null) { + return direction === "up" ? 0 : listLength - 1; + } + + return direction === "up" + ? (currentIndex + 1) % listLength + : (currentIndex - 1 + listLength) % listLength; +} From 74b304558c8d9f200e90942b946e59c17e0e3b5c Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 17 Oct 2025 22:16:50 +0800 Subject: [PATCH 10/10] apply nit --- frontend/src/components/editor/ai/ai-completion-editor.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/components/editor/ai/ai-completion-editor.tsx b/frontend/src/components/editor/ai/ai-completion-editor.tsx index f02f500fa89..a5c5084f422 100644 --- a/frontend/src/components/editor/ai/ai-completion-editor.tsx +++ b/frontend/src/components/editor/ai/ai-completion-editor.tsx @@ -193,7 +193,7 @@ export const AiCompletionEditor: React.FC = ({ const showCompletionBanner = enabled && triggerImmediately && (completion || isLoading); // Set default output area to below if not specified - outputArea = outputArea === undefined ? "below" : outputArea; + outputArea = outputArea ?? "below"; const showInput = enabled && (!triggerImmediately || showInputPrompt);