diff --git a/frontend/src/components/editor/ai/ai-completion-editor.tsx b/frontend/src/components/editor/ai/ai-completion-editor.tsx index 91e0344219b..1ba5f9ed13b 100644 --- a/frontend/src/components/editor/ai/ai-completion-editor.tsx +++ b/frontend/src/components/editor/ai/ai-completion-editor.tsx @@ -409,14 +409,14 @@ const CompletionBanner: React.FC = ({
{isLoading ? ( ) : ( { const getOutputString = (): string => { const text = consoleOutputs .filter((output) => output.channel !== "pdb") - .map((output) => { - if ( - output.mimetype.startsWith("application/vnd.marimo") || - output.mimetype === "text/html" - ) { - return parseHtmlContent(Strings.asString(output.data)); - } - - // Convert ANSI to HTML, then parse as HTML - return ansiToPlainText(Strings.asString(output.data)); - }) + .map((output) => processOutput(output)) .join("\n"); return text; }; @@ -282,3 +272,16 @@ const renderText = (text: string | null) => { ); }; + +/** Convert cell or console output to a string, while handling html and ansi codes */ +export const processOutput = (output: OutputMessage): string => { + if ( + output.mimetype.startsWith("application/vnd.marimo") || + output.mimetype === "text/html" + ) { + return parseHtmlContent(Strings.asString(output.data)); + } + + // Convert ANSI to HTML, then parse as HTML + return ansiToPlainText(Strings.asString(output.data)); +}; diff --git a/frontend/src/core/ai/context/providers/__tests__/cell-output.test.ts b/frontend/src/core/ai/context/providers/__tests__/cell-output.test.ts index f02e6b7017e..b4fe9227eef 100644 --- a/frontend/src/core/ai/context/providers/__tests__/cell-output.test.ts +++ b/frontend/src/core/ai/context/providers/__tests__/cell-output.test.ts @@ -1,6 +1,6 @@ /* Copyright 2024 Marimo. All rights reserved. */ -import { beforeEach, describe, expect, it, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { Mocks } from "@/__mocks__/common"; // Mock the external dependencies @@ -320,8 +320,12 @@ describe("Cell output utility functions", () => { describe("HTML content parsing", () => { let provider: CellOutputContextProvider; let mockStore: JotaiStore; + let originalCreateElement: typeof document.createElement; beforeEach(() => { + // Save original createElement + originalCreateElement = document.createElement; + // Mock DOM methods for HTML parsing const mockDiv = { innerHTML: "", @@ -343,6 +347,13 @@ describe("Cell output utility functions", () => { provider = new CellOutputContextProvider(mockStore); }); + afterEach(() => { + // Restore original createElement to prevent mock pollution + if (originalCreateElement) { + document.createElement = originalCreateElement; + } + }); + it("should extract text content from HTML", () => { const htmlContent = "

Hello world!

"; diff --git a/frontend/src/core/ai/context/providers/cell-output.ts b/frontend/src/core/ai/context/providers/cell-output.ts index c22aad4c207..0ad858000e4 100644 --- a/frontend/src/core/ai/context/providers/cell-output.ts +++ b/frontend/src/core/ai/context/providers/cell-output.ts @@ -3,13 +3,13 @@ import type { Completion } from "@codemirror/autocomplete"; import type { FileUIPart } from "ai"; import { toPng } from "html-to-image"; +import { processOutput } from "@/components/editor/output/ConsoleOutput"; import { type NotebookState, notebookAtom } from "@/core/cells/cells"; import { type CellId, CellOutputId } from "@/core/cells/ids"; import { displayCellName } from "@/core/cells/names"; import { isOutputEmpty } from "@/core/cells/outputs"; import type { OutputMessage } from "@/core/kernel/messages"; import type { JotaiStore } from "@/core/state/jotai"; -import { parseHtmlContent } from "@/utils/dom"; import { Logger } from "@/utils/Logger"; import { type AIContextItem, AIContextProvider } from "../registry"; import { contextToXml } from "../utils"; @@ -320,16 +320,10 @@ function getBaseOutput(output: OutputMessage): BaseOutput | null { const isMedia = isMediaMimetype(mimetype, String(output.data)); const outputType = isMedia ? "media" : "text"; - let processedContent: string | undefined; + const processedContent = processOutput(output); let imageUrl: string | undefined; let shouldDownloadImage = false; - // Process text content - if (outputType === "text" && typeof output.data === "string") { - processedContent = - mimetype === "text/html" ? parseHtmlContent(output.data) : output.data; - } - // Process media content - for now, we'll just note that it's media if (outputType === "media") { if (typeof output.data === "string" && output.data.startsWith("data:")) { diff --git a/frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts b/frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts index d4299858a78..fc1d23960f1 100644 --- a/frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts +++ b/frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts @@ -44,7 +44,7 @@ describe("RunStaleCellsTool", () => { store, }; - tool = new RunStaleCellsTool(); + tool = new RunStaleCellsTool({ postExecutionDelay: 0 }); cellId1 = "cell-1" as CellId; cellId2 = "cell-2" as CellId; diff --git a/frontend/src/core/ai/tools/edit-notebook-tool.ts b/frontend/src/core/ai/tools/edit-notebook-tool.ts index fd62869e9b9..b390ffd3916 100644 --- a/frontend/src/core/ai/tools/edit-notebook-tool.ts +++ b/frontend/src/core/ai/tools/edit-notebook-tool.ts @@ -109,15 +109,19 @@ export class EditNotebookTool scrollAndHighlightCell(cellId); - // If previous code exists, we don't want to replace it, it means there is a new edit on top of the previous edit + const existingStagedCell = store.get(stagedAICellsAtom).get(cellId); + + // If previous edit was from a new cell, just replace editor code with the new code + if (existingStagedCell?.type === "add_cell") { + updateEditorCodeFromPython(editorView, code); + break; + } + + // If previous code exists, we don't want to replace it, it means there is a new edit on top of the previous update // Keep the original code - const stagedCell = store.get(stagedAICellsAtom).get(cellId); const currentCellCode = editorView.state.doc.toString(); const previousCode = - stagedCell?.type === "update_cell" || - stagedCell?.type === "delete_cell" - ? stagedCell.previousCode - : currentCellCode; + existingStagedCell?.previousCode ?? currentCellCode; addStagedCell({ cellId, @@ -186,7 +190,11 @@ export class EditNotebookTool } return { status: "success", - next_steps: ["If you need to perform more edits, call this tool again."], + next_steps: [ + "If you need to perform more edits, call this tool again.", + "You should use the lint notebook tool to check for errors and lint issues. Fix them by editing the notebook.", + "You should use the run stale cells tool to run the cells that have been edited or newly added. This allows you to see the output of the cells and fix any errors.", + ], }; }; diff --git a/frontend/src/core/ai/tools/run-cells-tool.ts b/frontend/src/core/ai/tools/run-cells-tool.ts index ac6796d4da7..b1d178b7beb 100644 --- a/frontend/src/core/ai/tools/run-cells-tool.ts +++ b/frontend/src/core/ai/tools/run-cells-tool.ts @@ -21,11 +21,12 @@ import { } from "./base"; import type { CopilotMode } from "./registry"; +const POST_EXECUTION_DELAY = 200; +const WAIT_FOR_CELLS_TIMEOUT = 30_000; + interface CellOutput { consoleOutput?: string; - // consoleAttachments?: FileUIPart[]; cellOutput?: string; - // cellAttachments?: FileUIPart[]; } // Must use Record instead of Map because Map serializes to JSON as {} @@ -67,6 +68,12 @@ export class RunStaleCellsTool }) satisfies z.ZodType; readonly mode: CopilotMode[] = ["agent"]; + private readonly postExecutionDelay: number; + + constructor(opts?: { postExecutionDelay?: number }) { + this.postExecutionDelay = opts?.postExecutionDelay ?? POST_EXECUTION_DELAY; + } + handler = async ( _args: EmptyToolInput, toolContext: ToolNotebookContext, @@ -91,7 +98,12 @@ export class RunStaleCellsTool }); // Wait for all cells to finish executing - const allCellsFinished = await this.waitForCellsToFinish(store, staleCells); + const allCellsFinished = await this.waitForCellsToFinish( + store, + staleCells, + WAIT_FOR_CELLS_TIMEOUT, + this.postExecutionDelay, + ); if (!allCellsFinished) { return { status: "success", @@ -105,6 +117,7 @@ export class RunStaleCellsTool const cellsToOutput = new Map(); let resultMessage = ""; + let outputHasErrors = false; for (const cellId of staleCells) { const cellContextData = getCellContextData(cellId, updatedNotebook, { @@ -112,13 +125,13 @@ export class RunStaleCellsTool }); let cellOutputString: string | undefined; - // let cellAttachments: FileUIPart[] | undefined; let consoleOutputString: string | undefined; - // let consoleAttachments: FileUIPart[] | undefined; const cellOutput = cellContextData.cellOutput; const consoleOutputs = cellContextData.consoleOutputs; - if (!cellOutput && !consoleOutputs) { + const hasConsoleOutput = consoleOutputs && consoleOutputs.length > 0; + + if (!cellOutput && !hasConsoleOutput) { // Set null to show no output cellsToOutput.set(cellId, null); continue; @@ -126,31 +139,26 @@ export class RunStaleCellsTool if (cellOutput) { cellOutputString = this.formatOutputString(cellOutput); - // cellAttachments = await getAttachmentsForOutputs( - // [cellOutput], - // cellId, - // cellContextData.cellName, - // ); + if (this.outputHasErrors(cellOutput)) { + outputHasErrors = true; + } } - if (consoleOutputs) { - // consoleAttachments = await getAttachmentsForOutputs( - // consoleOutputs, - // cellId, - // cellContextData.cellName, - // ); + if (hasConsoleOutput) { consoleOutputString = consoleOutputs .map((output) => this.formatOutputString(output)) .join("\n"); resultMessage += "Console output represents the stdout or stderr of the cell (eg. print statements)."; + + if (consoleOutputs.some((output) => this.outputHasErrors(output))) { + outputHasErrors = true; + } } cellsToOutput.set(cellId, { cellOutput: cellOutputString, consoleOutput: consoleOutputString, - // cellAttachments: cellAttachments, - // consoleAttachments: consoleAttachments, }); } @@ -161,16 +169,32 @@ export class RunStaleCellsTool }; } + const nextSteps = [ + "Review the output of the cells. The CellId is the key of the result object.", + outputHasErrors + ? "There are errors in the cells. Please fix them by using the edit notebook tool and the given CellIds." + : "You may edit the notebook further with the given CellIds.", + ]; + return { status: "success", cellsToOutput: Object.fromEntries(cellsToOutput), message: resultMessage === "" ? undefined : resultMessage, - next_steps: [ - "Review the output of the cells, if you need to make any changes, you can run this tool again, the CellId is the key of the result object if you want to edit the cell.", - ], + next_steps: nextSteps, }; }; + private outputHasErrors(cellOutput: BaseOutput): boolean { + const { output } = cellOutput; + if ( + output.mimetype === "application/vnd.marimo+error" || + output.mimetype === "application/vnd.marimo+traceback" + ) { + return true; + } + return false; + } + private formatOutputString(cellOutput: BaseOutput): string { let outputString = ""; const { outputType, processedContent, imageUrl, output } = cellOutput; @@ -199,7 +223,8 @@ export class RunStaleCellsTool private async waitForCellsToFinish( store: JotaiStore, cellIds: CellId[], - timeout = 30_000, + timeout: number, + postExecutionDelay: number, ): Promise { const checkAllFinished = ( notebook: ReturnType, @@ -212,9 +237,18 @@ export class RunStaleCellsTool }); }; - // If already finished, return immediately - if (checkAllFinished(store.get(notebookAtom))) { + // Add a small delay after cells finish to allow console outputs to arrive + // Console outputs are streamed and might still be in-flight + const delayForConsoleOutputs = async () => { + if (postExecutionDelay > 0) { + await new Promise((resolve) => setTimeout(resolve, postExecutionDelay)); + } return true; + }; + + // Return immediately if all cells are finished + if (checkAllFinished(store.get(notebookAtom))) { + return await delayForConsoleOutputs(); } // Wait for notebook state changes with timeout @@ -225,7 +259,7 @@ export class RunStaleCellsTool setTimeout(() => reject(new Error("timeout")), timeout), ), ]); - return true; + return await delayForConsoleOutputs(); } catch { return false; } diff --git a/marimo/_ai/_tools/tools/datasource.py b/marimo/_ai/_tools/tools/datasource.py index 941c508ffec..ed8f47da91a 100644 --- a/marimo/_ai/_tools/tools/datasource.py +++ b/marimo/_ai/_tools/tools/datasource.py @@ -42,15 +42,13 @@ class GetDatabaseTables( ToolBase[GetDatabaseTablesArgs, GetDatabaseTablesOutput] ): """ - Get information about tables in a database. Use the query parameter to search by name. Use regex for complex searches. + Get information about tables in a database. Use the query parameter to search by name. You can use regex. Args: session_id: The session id. query (optional): The query to match the database, schemas, and tables. - If a query is provided, it will fuzzy match the query to the database, schemas, and tables available. If no query is provided, all tables are returned. Don't provide a query if you need to see the entire schema view. - - The tables returned contain information about the database, schema and connection name to use in forming SQL queries. + If a query is provided, it will fuzzy match the query to the database, schemas, and tables available. If no query is provided, all tables are returned. """ guidelines = ToolGuidelines( @@ -62,8 +60,10 @@ class GetDatabaseTables( "You must have a valid session id from an active notebook", ], avoid_if=[ - "the user is asking about in-memory DataFrames, use the get_tables_and_variables tool instead", + "You have already been given the schema view, you can refer to the given information", + "The user is asking about in-memory DataFrames, use the get_tables_and_variables tool instead", ], + additional_info="For best results, don't provide a query since you may miss some tables. Alternatively, provide loose queries using regex that can match uppercase/lowercase and plural or singular forms.", ) def handle(self, args: GetDatabaseTablesArgs) -> GetDatabaseTablesOutput: diff --git a/marimo/_server/ai/prompts.py b/marimo/_server/ai/prompts.py index 67d6df137e8..6d94360aedc 100644 --- a/marimo/_server/ai/prompts.py +++ b/marimo/_server/ai/prompts.py @@ -36,7 +36,7 @@ language_rules_multiple_cells: dict[Language, list[str]] = { "sql": [ - 'SQL cells start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines.', + 'SQL cells start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines. You should always write queries inline as the code snippet above, do not use variables to store queries.', "This will automatically display the result in the UI. You do not need to return the dataframe in the cell.", "The SQL must use the syntax of the database engine specified in the `engine` variable. If no engine, then use duckdb syntax.", ] @@ -247,6 +247,15 @@ def _get_mode_intro_message(mode: CopilotMode) -> str: elif mode == "agent": return ( f"{base_intro}" + "You are in agent mode - you have autonomy to resolve the user's query by using the tools provided. Please keep going until the user's query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. \n" + "\n\n## Agent Mode\n" + "- You are encouraged to edit existing cells in the notebook or add new cells.\n" + "- You should do the following things after editing the notebook:\n" + "\t 1. Use the lint notebook tool to check for errors and lint issues\n" + "\t 2. Run stale cells tool to run the code\n" + "\t 3. If there are errors in cells you have added, edit the existing cell. Don't add new cells to correct errors.\n" + "- If you say you're about to do something, actually do it in the same turn (run the tool call right after).\n" + "- Group code into logical cells, eg. functions should be in separate cells and all the calls will be in one cell. When asked for explanations or summaries, use markdown cells with proper formatting.\n\n" "## Capabilities\n" "- You can use a set of read and write tools to gather additional context from the notebook or environment (e.g., searching code, summarizing data, or reading documentation) and to modify the notebook (e.g., adding cells, editing cells, deleting cells).\n" "## Limitations\n" diff --git a/tests/_server/ai/snapshots/chat_system_prompts.txt b/tests/_server/ai/snapshots/chat_system_prompts.txt index 99c3d312a9c..80b3f0902f6 100644 --- a/tests/_server/ai/snapshots/chat_system_prompts.txt +++ b/tests/_server/ai/snapshots/chat_system_prompts.txt @@ -894,6 +894,18 @@ import numpy as np You are Marimo Copilot, an AI assistant integrated into the marimo notebook code editor. Your primary function is to help users create, analyze, and improve data science notebooks using marimo's reactive programming model. +You are in agent mode - you have autonomy to resolve the user's query by using the tools provided. Please keep going until the user's query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. + + +## Agent Mode +- You are encouraged to edit existing cells in the notebook or add new cells. +- You should do the following things after editing the notebook: + 1. Use the lint notebook tool to check for errors and lint issues + 2. Run stale cells tool to run the code + 3. If there are errors in cells you have added, edit the existing cell. Don't add new cells to correct errors. +- If you say you're about to do something, actually do it in the same turn (run the tool call right after). +- Group code into logical cells, eg. functions should be in separate cells and all the calls will be in one cell. When asked for explanations or summaries, use markdown cells with proper formatting. + ## Capabilities - You can use a set of read and write tools to gather additional context from the notebook or environment (e.g., searching code, summarizing data, or reading documentation) and to modify the notebook (e.g., adding cells, editing cells, deleting cells). ## Limitations @@ -1027,7 +1039,7 @@ chart 7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning. ## Rules for sql: -1. SQL cells start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines. +1. SQL cells start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines. You should always write queries inline as the code snippet above, do not use variables to store queries. 2. This will automatically display the result in the UI. You do not need to return the dataframe in the cell. 3. The SQL must use the syntax of the database engine specified in the `engine` variable. If no engine, then use duckdb syntax. diff --git a/tests/_server/ai/snapshots/system_prompts.txt b/tests/_server/ai/snapshots/system_prompts.txt index a6910d62d4e..8902f4cdb9b 100644 --- a/tests/_server/ai/snapshots/system_prompts.txt +++ b/tests/_server/ai/snapshots/system_prompts.txt @@ -347,7 +347,7 @@ Separate logic into multiple cells to keep the code organized and readable. 7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning. ## Rules for sql: -1. SQL cells start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines. +1. SQL cells start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines. You should always write queries inline as the code snippet above, do not use variables to store queries. 2. This will automatically display the result in the UI. You do not need to return the dataframe in the cell. 3. The SQL must use the syntax of the database engine specified in the `engine` variable. If no engine, then use duckdb syntax.