Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions frontend/src/components/editor/ai/ai-completion-editor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
// Use complete to pass the prompt directly, else input might be empty
complete(initialPrompt);
}
// eslint-disable-next-line react-hooks/exhaustive-deps

Check warning on line 157 in frontend/src/components/editor/ai/ai-completion-editor.tsx

View workflow job for this annotation

GitHub Actions / 🧹 Lint frontend

React Compiler has skipped optimizing this component because one or more React ESLint rules were disabled. React Compiler only works when your components follow all the rules of React, disabling them may result in unexpected or incorrect behavior
}, [triggerImmediately]);

// Focus the input
Expand Down Expand Up @@ -409,14 +409,14 @@
<div className="flex flex-row items-center gap-2">
{isLoading ? (
<Loader2Icon
className="animate-spin text-blue-600 mb-[1px]"
className="animate-spin text-blue-600 mb-px"
size={15}
strokeWidth={2}
aria-label="Generating fix"
/>
) : (
<CircleCheckIcon
className="text-green-600 mb-[1px]"
className="text-green-600 mb-px"
size={15}
strokeWidth={2}
aria-label="Fix generated"
Expand Down
25 changes: 14 additions & 11 deletions frontend/src/components/editor/output/ConsoleOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,7 @@ const ConsoleOutputInternal = (props: Props): React.ReactNode => {
const getOutputString = (): string => {
const text = consoleOutputs
.filter((output) => output.channel !== "pdb")
.map((output) => {
if (
output.mimetype.startsWith("application/vnd.marimo") ||
output.mimetype === "text/html"
) {
return parseHtmlContent(Strings.asString(output.data));
}

// Convert ANSI to HTML, then parse as HTML
return ansiToPlainText(Strings.asString(output.data));
})
.map((output) => processOutput(output))
.join("\n");
return text;
};
Expand Down Expand Up @@ -282,3 +272,16 @@ const renderText = (text: string | null) => {
<span dangerouslySetInnerHTML={{ __html: ansiUp.ansi_to_html(text) }} />
);
};

/**
 * Flatten a single cell or console output message into a string.
 *
 * Marimo-specific payloads (mimetype starting with `application/vnd.marimo`)
 * and `text/html` payloads are passed through `parseHtmlContent`; every other
 * mimetype is passed through `ansiToPlainText`.
 */
export const processOutput = (output: OutputMessage): string => {
  const { mimetype, data } = output;
  const isHtmlLike =
    mimetype === "text/html" || mimetype.startsWith("application/vnd.marimo");
  const raw = Strings.asString(data);

  // HTML-like payloads are parsed as HTML; anything else is treated as
  // (possibly ANSI-coloured) text and reduced to plain text.
  return isHtmlLike ? parseHtmlContent(raw) : ansiToPlainText(raw);
};
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* Copyright 2024 Marimo. All rights reserved. */

import { beforeEach, describe, expect, it, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { Mocks } from "@/__mocks__/common";

// Mock the external dependencies
Expand Down Expand Up @@ -320,8 +320,12 @@ describe("Cell output utility functions", () => {
describe("HTML content parsing", () => {
let provider: CellOutputContextProvider;
let mockStore: JotaiStore;
let originalCreateElement: typeof document.createElement;

beforeEach(() => {
// Save original createElement
originalCreateElement = document.createElement;

// Mock DOM methods for HTML parsing
const mockDiv = {
innerHTML: "",
Expand All @@ -343,6 +347,13 @@ describe("Cell output utility functions", () => {
provider = new CellOutputContextProvider(mockStore);
});

afterEach(() => {
// Restore original createElement to prevent mock pollution
if (originalCreateElement) {
document.createElement = originalCreateElement;
}
});

it("should extract text content from HTML", () => {
const htmlContent = "<p>Hello <strong>world</strong>!</p>";

Expand Down
10 changes: 2 additions & 8 deletions frontend/src/core/ai/context/providers/cell-output.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import type { Completion } from "@codemirror/autocomplete";
import type { FileUIPart } from "ai";
import { toPng } from "html-to-image";
import { processOutput } from "@/components/editor/output/ConsoleOutput";
import { type NotebookState, notebookAtom } from "@/core/cells/cells";
import { type CellId, CellOutputId } from "@/core/cells/ids";
import { displayCellName } from "@/core/cells/names";
import { isOutputEmpty } from "@/core/cells/outputs";
import type { OutputMessage } from "@/core/kernel/messages";
import type { JotaiStore } from "@/core/state/jotai";
import { parseHtmlContent } from "@/utils/dom";
import { Logger } from "@/utils/Logger";
import { type AIContextItem, AIContextProvider } from "../registry";
import { contextToXml } from "../utils";
Expand Down Expand Up @@ -320,16 +320,10 @@ function getBaseOutput(output: OutputMessage): BaseOutput | null {
const isMedia = isMediaMimetype(mimetype, String(output.data));
const outputType = isMedia ? "media" : "text";

let processedContent: string | undefined;
const processedContent = processOutput(output);
let imageUrl: string | undefined;
let shouldDownloadImage = false;

// Process text content
if (outputType === "text" && typeof output.data === "string") {
processedContent =
mimetype === "text/html" ? parseHtmlContent(output.data) : output.data;
}

// Process media content - for now, we'll just note that it's media
if (outputType === "media") {
if (typeof output.data === "string" && output.data.startsWith("data:")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ describe("RunStaleCellsTool", () => {
store,
};

tool = new RunStaleCellsTool();
tool = new RunStaleCellsTool({ postExecutionDelay: 0 });

cellId1 = "cell-1" as CellId;
cellId2 = "cell-2" as CellId;
Expand Down
22 changes: 15 additions & 7 deletions frontend/src/core/ai/tools/edit-notebook-tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,19 @@ export class EditNotebookTool

scrollAndHighlightCell(cellId);

// If previous code exists, we don't want to replace it, it means there is a new edit on top of the previous edit
const existingStagedCell = store.get(stagedAICellsAtom).get(cellId);

// If previous edit was from a new cell, just replace editor code with the new code
if (existingStagedCell?.type === "add_cell") {
updateEditorCodeFromPython(editorView, code);
break;
}

// If previous code exists, we don't want to replace it, it means there is a new edit on top of the previous update
// Keep the original code
const stagedCell = store.get(stagedAICellsAtom).get(cellId);
const currentCellCode = editorView.state.doc.toString();
const previousCode =
stagedCell?.type === "update_cell" ||
stagedCell?.type === "delete_cell"
? stagedCell.previousCode
: currentCellCode;
existingStagedCell?.previousCode ?? currentCellCode;

addStagedCell({
cellId,
Expand Down Expand Up @@ -186,7 +190,11 @@ export class EditNotebookTool
}
return {
status: "success",
next_steps: ["If you need to perform more edits, call this tool again."],
next_steps: [
"If you need to perform more edits, call this tool again.",
"You should use the lint notebook tool to check for errors and lint issues. Fix them by editing the notebook.",
"You should use the run stale cells tool to run the cells that have been edited or newly added. This allows you to see the output of the cells and fix any errors.",
],
};
};

Expand Down
86 changes: 60 additions & 26 deletions frontend/src/core/ai/tools/run-cells-tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ import {
} from "./base";
import type { CopilotMode } from "./registry";

const POST_EXECUTION_DELAY = 200;
const WAIT_FOR_CELLS_TIMEOUT = 30_000;

interface CellOutput {
consoleOutput?: string;
// consoleAttachments?: FileUIPart[];
cellOutput?: string;
// cellAttachments?: FileUIPart[];
}

// Must use Record instead of Map because Map serializes to JSON as {}
Expand Down Expand Up @@ -67,6 +68,12 @@ export class RunStaleCellsTool
}) satisfies z.ZodType<RunStaleCellsOutput>;
readonly mode: CopilotMode[] = ["agent"];

private readonly postExecutionDelay: number;

/**
 * @param opts - Optional settings.
 * @param opts.postExecutionDelay - Delay in milliseconds applied after the
 *   cells finish executing. Falls back to `POST_EXECUTION_DELAY` when `opts`
 *   or this field is omitted (null/undefined); an explicit `0` is kept as-is.
 */
constructor(opts?: { postExecutionDelay?: number }) {
this.postExecutionDelay = opts?.postExecutionDelay ?? POST_EXECUTION_DELAY;
}

handler = async (
_args: EmptyToolInput,
toolContext: ToolNotebookContext,
Expand All @@ -91,7 +98,12 @@ export class RunStaleCellsTool
});

// Wait for all cells to finish executing
const allCellsFinished = await this.waitForCellsToFinish(store, staleCells);
const allCellsFinished = await this.waitForCellsToFinish(
store,
staleCells,
WAIT_FOR_CELLS_TIMEOUT,
this.postExecutionDelay,
);
if (!allCellsFinished) {
return {
status: "success",
Expand All @@ -105,52 +117,48 @@ export class RunStaleCellsTool

const cellsToOutput = new Map<CellId, CellOutput | null>();
let resultMessage = "";
let outputHasErrors = false;

for (const cellId of staleCells) {
const cellContextData = getCellContextData(cellId, updatedNotebook, {
includeConsoleOutput: true,
});

let cellOutputString: string | undefined;
// let cellAttachments: FileUIPart[] | undefined;
let consoleOutputString: string | undefined;
// let consoleAttachments: FileUIPart[] | undefined;

const cellOutput = cellContextData.cellOutput;
const consoleOutputs = cellContextData.consoleOutputs;
if (!cellOutput && !consoleOutputs) {
const hasConsoleOutput = consoleOutputs && consoleOutputs.length > 0;

if (!cellOutput && !hasConsoleOutput) {
// Set null to show no output
cellsToOutput.set(cellId, null);
continue;
}

if (cellOutput) {
cellOutputString = this.formatOutputString(cellOutput);
// cellAttachments = await getAttachmentsForOutputs(
// [cellOutput],
// cellId,
// cellContextData.cellName,
// );
if (this.outputHasErrors(cellOutput)) {
outputHasErrors = true;
}
}

if (consoleOutputs) {
// consoleAttachments = await getAttachmentsForOutputs(
// consoleOutputs,
// cellId,
// cellContextData.cellName,
// );
if (hasConsoleOutput) {
consoleOutputString = consoleOutputs
.map((output) => this.formatOutputString(output))
.join("\n");
resultMessage +=
"Console output represents the stdout or stderr of the cell (eg. print statements).";

if (consoleOutputs.some((output) => this.outputHasErrors(output))) {
outputHasErrors = true;
}
}

cellsToOutput.set(cellId, {
cellOutput: cellOutputString,
consoleOutput: consoleOutputString,
// cellAttachments: cellAttachments,
// consoleAttachments: consoleAttachments,
});
}

Expand All @@ -161,16 +169,32 @@ export class RunStaleCellsTool
};
}

const nextSteps = [
"Review the output of the cells. The CellId is the key of the result object.",
outputHasErrors
? "There are errors in the cells. Please fix them by using the edit notebook tool and the given CellIds."
: "You may edit the notebook further with the given CellIds.",
];

return {
status: "success",
cellsToOutput: Object.fromEntries(cellsToOutput),
message: resultMessage === "" ? undefined : resultMessage,
next_steps: [
"Review the output of the cells, if you need to make any changes, you can run this tool again, the CellId is the key of the result object if you want to edit the cell.",
],
next_steps: nextSteps,
};
};

/**
 * Report whether an output represents a failure.
 *
 * @returns `true` when the underlying output's mimetype is the marimo error
 *   or marimo traceback mimetype, `false` otherwise.
 */
private outputHasErrors(cellOutput: BaseOutput): boolean {
  const mimetype = cellOutput.output.mimetype;
  return (
    mimetype === "application/vnd.marimo+error" ||
    mimetype === "application/vnd.marimo+traceback"
  );
}

private formatOutputString(cellOutput: BaseOutput): string {
let outputString = "";
const { outputType, processedContent, imageUrl, output } = cellOutput;
Expand Down Expand Up @@ -199,7 +223,8 @@ export class RunStaleCellsTool
private async waitForCellsToFinish(
store: JotaiStore,
cellIds: CellId[],
timeout = 30_000,
timeout: number,
postExecutionDelay: number,
): Promise<boolean> {
const checkAllFinished = (
notebook: ReturnType<typeof notebookAtom.read>,
Expand All @@ -212,9 +237,18 @@ export class RunStaleCellsTool
});
};

// If already finished, return immediately
if (checkAllFinished(store.get(notebookAtom))) {
// Add a small delay after cells finish to allow console outputs to arrive
// Console outputs are streamed and might still be in-flight
const delayForConsoleOutputs = async () => {
if (postExecutionDelay > 0) {
await new Promise((resolve) => setTimeout(resolve, postExecutionDelay));
}
return true;
};

// Return immediately if all cells are finished
if (checkAllFinished(store.get(notebookAtom))) {
return await delayForConsoleOutputs();
}

// Wait for notebook state changes with timeout
Expand All @@ -225,7 +259,7 @@ export class RunStaleCellsTool
setTimeout(() => reject(new Error("timeout")), timeout),
),
]);
return true;
return await delayForConsoleOutputs();
} catch {
return false;
}
Expand Down
10 changes: 5 additions & 5 deletions marimo/_ai/_tools/tools/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@ class GetDatabaseTables(
ToolBase[GetDatabaseTablesArgs, GetDatabaseTablesOutput]
):
"""
Get information about tables in a database. Use the query parameter to search by name. Use regex for complex searches.
Get information about tables in a database. Use the query parameter to search by name. You can use regex.

Args:
session_id: The session id.
query (optional): The query to match the database, schemas, and tables.

If a query is provided, it will fuzzy match the query to the database, schemas, and tables available. If no query is provided, all tables are returned. Don't provide a query if you need to see the entire schema view.

The tables returned contain information about the database, schema and connection name to use in forming SQL queries.
If a query is provided, it will fuzzy match the query to the database, schemas, and tables available. If no query is provided, all tables are returned.
"""

guidelines = ToolGuidelines(
Expand All @@ -62,8 +60,10 @@ class GetDatabaseTables(
"You must have a valid session id from an active notebook",
],
avoid_if=[
"the user is asking about in-memory DataFrames, use the get_tables_and_variables tool instead",
"You have already been given the schema view, you can refer to the given information",
"The user is asking about in-memory DataFrames, use the get_tables_and_variables tool instead",
],
additional_info="For best results, don't provide a query since you may miss some tables. Alternatively, provide loose queries using regex that can match uppercase/lowercase and plural or singular forms.",
)

def handle(self, args: GetDatabaseTablesArgs) -> GetDatabaseTablesOutput:
Expand Down
Loading
Loading