diff --git a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts index 24574d379ce..5e7ec21445a 100644 --- a/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts +++ b/frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts @@ -149,6 +149,31 @@ describe("EditNotebookTool", () => { type: "update_cell", previousCode: oldCode, }); + + // Update cell again + const result2 = await tool.handler( + { + edit: { + type: "update_cell", + cellId: cellId1, + code: "x = 3", + }, + }, + toolContext as never, + ); + + expect(result2.status).toBe("success"); + expect(vi.mocked(updateEditorCodeFromPython)).toHaveBeenCalledWith( + editorView, + "x = 3", + ); + + const stagedCells2 = store.get(stagedAICellsAtom); + expect(stagedCells2.has(cellId1)).toBe(true); + expect(stagedCells2.get(cellId1)).toEqual({ + type: "update_cell", + previousCode: oldCode, // Should keep the original code + }); }); it("should throw error when cell ID doesn't exist", async () => { diff --git a/frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts b/frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts index c11307576c8..d4299858a78 100644 --- a/frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts +++ b/frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts @@ -332,6 +332,45 @@ describe("RunStaleCellsTool", () => { expect(result.cellsToOutput?.[cellId1]?.consoleOutput).toContain("debug"); }); + it("should handle cell output with object output", async () => { + const notebook = MockNotebook.notebookState({ + cellData: { + [cellId1]: { code: "x = {'a': 1, 'b': 2}", edited: true }, + }, + }); + store.set(notebookAtom, notebook); + + // Mock runCells + vi.mocked(runCells).mockImplementation(async () => { + const updatedNotebook = store.get(notebookAtom); + updatedNotebook.cellRuntime[cellId1] = { + ...updatedNotebook.cellRuntime[cellId1], + status: "idle", + }; + store.set(notebookAtom, updatedNotebook); + }); + + // Mock getCellContextData to return object output + vi.mocked(getCellContextData).mockReturnValue({ + cellOutput: { + outputType: "text", + processedContent: null, + imageUrl: null, + output: { data: JSON.stringify({ a: 1, b: 2 }) }, + }, + consoleOutputs: null, + cellName: "cell1", + } as never); + + const result = await tool.handler({}, toolContext as never); + + expect(result.status).toBe("success"); + expect(result.cellsToOutput).toBeDefined(); + expect(result.cellsToOutput?.[cellId1]?.cellOutput).toEqual( + 'Output:\n{"a":1,"b":2}', + ); + }); + it("should return success when all stale cells have no output", async () => { const notebook = MockNotebook.notebookState({ cellData: { diff --git a/frontend/src/core/ai/tools/edit-notebook-tool.ts b/frontend/src/core/ai/tools/edit-notebook-tool.ts index bccfb477e84..3fd29fec09c 100644 --- a/frontend/src/core/ai/tools/edit-notebook-tool.ts +++ b/frontend/src/core/ai/tools/edit-notebook-tool.ts @@ -11,6 +11,7 @@ import { import { CellId } from "@/core/cells/ids"; import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils"; import type { CellColumnId } from "@/utils/id-tree"; +import { stagedAICellsAtom } from "../staged-cells"; import { type AiTool, type ToolDescription, @@ -23,9 +24,9 @@ import type { CopilotMode } from "./registry"; const description: ToolDescription = { baseDescription: - "Perform editing operations on the current notebook. Call this tool multiple times to perform multiple edits.", + "Perform editing operations on the current notebook. You should prefer to create new cells unless you need to edit existing cells. Call this tool multiple times to perform multiple edits. Separate code into logical individual cells to take advantage of the notebook's reactive execution model.", prerequisites: [ - "Find out the cellIds and columnIds first (call lightweight cell map tool)", + "If you are updating existing cells, you need the cellIds or columnIds. If they are not known, call the lightweight_cell_map_tool to find out.", ], additionalInfo: ` Args: @@ -108,10 +109,19 @@ export class EditNotebookTool scrollAndHighlightCell(cellId); + // If previous code exists, we don't want to replace it, it means there is a new edit on top of the previous edit + // Keep the original code + const stagedCell = store.get(stagedAICellsAtom).get(cellId); const currentCellCode = editorView.state.doc.toString(); + const previousCode = + stagedCell?.type === "update_cell" || + stagedCell?.type === "delete_cell" + ? stagedCell.previousCode + : currentCellCode; + addStagedCell({ cellId, - edit: { type: "update_cell", previousCode: currentCellCode }, + edit: { type: "update_cell", previousCode: previousCode }, }); updateEditorCodeFromPython(editorView, code); diff --git a/frontend/src/core/ai/tools/run-cells-tool.ts b/frontend/src/core/ai/tools/run-cells-tool.ts index 2c162292ae2..ac6796d4da7 100644 --- a/frontend/src/core/ai/tools/run-cells-tool.ts +++ b/frontend/src/core/ai/tools/run-cells-tool.ts @@ -165,14 +165,24 @@ export class RunStaleCellsTool status: "success", cellsToOutput: Object.fromEntries(cellsToOutput), message: resultMessage === "" ? undefined : resultMessage, + next_steps: [ + "Review the output of the cells, if you need to make any changes, you can run this tool again, the CellId is the key of the result object if you want to edit the cell.", + ], }; }; private formatOutputString(cellOutput: BaseOutput): string { let outputString = ""; const { outputType, processedContent, imageUrl, output } = cellOutput; - if (outputType === "text" && processedContent) { - outputString += `Output:\n${processedContent}`; + if (outputType === "text") { + outputString += "Output:\n"; + if (processedContent) { + outputString += processedContent; + } else if (typeof output.data === "string") { + outputString += output.data; + } else { + outputString += JSON.stringify(output.data); + } } else if (outputType === "media") { outputString += `Media Output: Contains ${output.mimetype} content`; if (imageUrl) { diff --git a/marimo/_ai/_tools/tools/datasource.py b/marimo/_ai/_tools/tools/datasource.py index 4225af44feb..941c508ffec 100644 --- a/marimo/_ai/_tools/tools/datasource.py +++ b/marimo/_ai/_tools/tools/datasource.py @@ -11,6 +11,7 @@ from marimo._ai._tools.utils.exceptions import ToolExecutionError from marimo._data.models import DataTable from marimo._server.sessions import Session +from marimo._sql.engines.duckdb import INTERNAL_DUCKDB_ENGINE from marimo._types.ids import SessionId from marimo._utils.fuzzy_match import compile_regex, is_fuzzy_match @@ -29,6 +30,7 @@ class TableDetails: database: str schema: str table: DataTable + sample_query: str @dataclass @@ -93,19 +95,30 @@ def _get_tables( for connection in data_connectors.connections: for database in connection.databases: + default_database = connection.default_database == database.name for schema in database.schemas: + default_schema = connection.default_schema == schema.name # If query is None, match all schemas # If matching, add all tables to the list if query is None or is_fuzzy_match( query, schema.name, compiled_pattern, is_regex ): for table in schema.tables: + sample_query = self._form_sample_query( + database=database.name, + schema=schema.name, + table=table.name, + default_database=default_database, + default_schema=default_schema, + engine=connection.name, + ) tables.append( TableDetails( connection=connection.name, database=database.name, schema=schema.name, table=table, + sample_query=sample_query, ) ) continue @@ -113,18 +126,50 @@ def _get_tables( if is_fuzzy_match( query, table.name, compiled_pattern, is_regex ): + sample_query = self._form_sample_query( + database=database.name, + schema=schema.name, + table=table.name, + default_database=default_database, + default_schema=default_schema, + engine=connection.name, + ) tables.append( TableDetails( connection=connection.name, database=database.name, schema=schema.name, table=table, + sample_query=sample_query, ) ) return GetDatabaseTablesOutput( tables=tables, next_steps=[ - 'Example of an SQL query: _df = mo.sql(f"""SELECT * FROM database.schema.name LIMIT 100""")', + "Use the sample query as a guideline to write your own SQL query." ], ) + + def _form_sample_query( + self, + *, + database: str, + schema: str, + table: str, + default_database: bool, + default_schema: bool, + engine: str, + ) -> str: + sample_query = f"SELECT * FROM {database}.{schema}.{table} LIMIT 100" + if default_database: + sample_query = f"SELECT * FROM {schema}.{table} LIMIT 100" + if default_schema: + sample_query = f"SELECT * FROM {table} LIMIT 100" + if engine != INTERNAL_DUCKDB_ENGINE: + wrapped_query = ( + f'df = mo.sql(f"""{sample_query}""", engine={engine})' + ) + else: + wrapped_query = f'df = mo.sql(f"""{sample_query}""")' + return wrapped_query diff --git a/tests/_ai/tools/tools/test_datasource_tool.py b/tests/_ai/tools/tools/test_datasource_tool.py index a1480eb1b5c..73357dcffd6 100644 --- a/tests/_ai/tools/tools/test_datasource_tool.py +++ b/tests/_ai/tools/tools/test_datasource_tool.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Optional import pytest @@ -15,6 +16,7 @@ from marimo._ai._tools.utils.exceptions import ToolExecutionError from marimo._data.models import Database, DataTable, DataTableColumn, Schema from marimo._messaging.ops import DataSourceConnections +from marimo._sql.engines.duckdb import INTERNAL_DUCKDB_ENGINE @dataclass @@ -22,6 +24,8 @@ class MockDataSourceConnection: name: str dialect: str databases: list[Database] + default_database: Optional[str] = None + default_schema: Optional[str] = None @dataclass @@ -51,9 +55,11 @@ def sample_table() -> DataTable: num_columns=3, variable_name=None, columns=[ - DataTableColumn("id", "int", "INTEGER", [1, 2, 3]), - DataTableColumn("name", "str", "VARCHAR", ["Alice", "Bob"]), - DataTableColumn("email", "str", "VARCHAR", ["alice@example.com"]), + DataTableColumn("id", "integer", "INTEGER", [1, 2, 3]), + DataTableColumn("name", "string", "VARCHAR", ["Alice", "Bob"]), + DataTableColumn( + "email", "string", "VARCHAR", ["alice@example.com"] + ), ], ) @@ -111,8 +117,8 @@ def multi_table_session() -> MockSession: num_columns=2, variable_name=None, columns=[ - DataTableColumn("id", "int", "INTEGER", [1, 2]), - DataTableColumn("name", "str", "VARCHAR", ["Alice"]), + DataTableColumn("id", "integer", "INTEGER", [1, 2]), + DataTableColumn("name", "string", "VARCHAR", ["Alice"]), ], ), DataTable( @@ -123,8 +129,8 @@ def multi_table_session() -> MockSession: num_columns=2, variable_name=None, columns=[ - DataTableColumn("order_id", "int", "INTEGER", [1]), - DataTableColumn("user_id", "int", "INTEGER", [1]), + DataTableColumn("order_id", "integer", "INTEGER", [1]), + DataTableColumn("user_id", "integer", "INTEGER", [1]), ], ), DataTable( @@ -135,8 +141,8 @@ def multi_table_session() -> MockSession: num_columns=2, variable_name=None, columns=[ - DataTableColumn("product_id", "int", "INTEGER", [1]), - DataTableColumn("name", "str", "VARCHAR", ["Widget"]), + DataTableColumn("product_id", "integer", "INTEGER", [1]), + DataTableColumn("name", "string", "VARCHAR", ["Widget"]), ], ), ] @@ -549,3 +555,108 @@ def test_query_no_duplicates(tool: GetDatabaseTables): assert "user_profiles" in table_names assert "user_settings" in table_names assert "user_reviews" in table_names + + +def test_form_sample_query_full_qualified(tool: GetDatabaseTables): + """Test forming a sample query with full qualified name (not default database or schema).""" + query = tool._form_sample_query( + database="mydb", + schema="myschema", + table="mytable", + default_database=False, + default_schema=False, + engine="postgres_conn", + ) + + assert ( + query + == 'df = mo.sql(f"""SELECT * FROM mydb.myschema.mytable LIMIT 100""", engine=postgres_conn)' + ) + + +def test_form_sample_query_default_database(tool: GetDatabaseTables): + """Test forming a sample query when database is default (schema.table).""" + query = tool._form_sample_query( + database="mydb", + schema="myschema", + table="mytable", + default_database=True, + default_schema=False, + engine="mysql_conn", + ) + + assert ( + query + == 'df = mo.sql(f"""SELECT * FROM myschema.mytable LIMIT 100""", engine=mysql_conn)' + ) + + +def test_form_sample_query_default_schema(tool: GetDatabaseTables): + """Test forming a sample query when schema is default (table only).""" + query = tool._form_sample_query( + database="mydb", + schema="myschema", + table="mytable", + default_database=False, + default_schema=True, + engine="postgres_conn", + ) + + assert ( + query + == 'df = mo.sql(f"""SELECT * FROM mytable LIMIT 100""", engine=postgres_conn)' + ) + + +def test_form_sample_query_both_defaults(tool: GetDatabaseTables): + """Test forming a sample query when both database and schema are default.""" + query = tool._form_sample_query( + database="mydb", + schema="myschema", + table="mytable", + default_database=True, + default_schema=True, + engine="mysql_conn", + ) + + assert ( + query + == 'df = mo.sql(f"""SELECT * FROM mytable LIMIT 100""", engine=mysql_conn)' + ) + + +def test_form_sample_query_internal_duckdb_no_defaults( + tool: GetDatabaseTables, +): + """Test forming a sample query with internal DuckDB engine (no engine parameter).""" + + query = tool._form_sample_query( + database="mydb", + schema="myschema", + table="mytable", + default_database=False, + default_schema=False, + engine=INTERNAL_DUCKDB_ENGINE, + ) + + assert ( + query + == 'df = mo.sql(f"""SELECT * FROM mydb.myschema.mytable LIMIT 100""")' + ) + + +def test_form_sample_query_internal_duckdb_with_defaults( + tool: GetDatabaseTables, +): + """Test forming a sample query with internal DuckDB engine and both defaults.""" + + query = tool._form_sample_query( + database="mydb", + schema="myschema", + table="mytable", + default_database=True, + default_schema=True, + engine=INTERNAL_DUCKDB_ENGINE, + ) + + assert query == 'df = mo.sql(f"""SELECT * FROM mytable LIMIT 100""")'