Skip to content

Commit f946a34

Browse files
authored
add sample query, better tool descriptions to improve agentic mode (#6880)
## 📝 Summary <!-- Provide a concise summary of what this pull request is addressing. If this PR fixes any issues, list them here by number (e.g., Fixes #123). --> https://github.com/user-attachments/assets/6658135d-a2b2-4286-b804-9653dcc10b94 Updates the logic of handling edit history, keeps the original code instead of latest edit as `previousCode` ## 🔍 Description of Changes <!-- Detail the specific changes made in this pull request. Explain the problem addressed and how it was resolved. If applicable, provide before and after comparisons, screenshots, or any relevant details to help reviewers understand the changes easily. --> ## 📋 Checklist - [x] I have read the [contributor guidelines](https://github.com/marimo-team/marimo/blob/main/CONTRIBUTING.md). - [ ] For large changes, or changes that affect the public API: this change was discussed or approved through an issue, on [Discord](https://marimo.io/discord?ref=pr), or the community [discussions](https://github.com/marimo-team/marimo/discussions) (Please provide a link if applicable). - [ ] I have added tests for the changes made. - [x] I have run the code and verified that it works as expected.
1 parent 306dc27 commit f946a34

File tree

6 files changed

+255
-15
lines changed

6 files changed

+255
-15
lines changed

frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,31 @@ describe("EditNotebookTool", () => {
149149
type: "update_cell",
150150
previousCode: oldCode,
151151
});
152+
153+
// Update cell again
154+
const result2 = await tool.handler(
155+
{
156+
edit: {
157+
type: "update_cell",
158+
cellId: cellId1,
159+
code: "x = 3",
160+
},
161+
},
162+
toolContext as never,
163+
);
164+
165+
expect(result2.status).toBe("success");
166+
expect(vi.mocked(updateEditorCodeFromPython)).toHaveBeenCalledWith(
167+
editorView,
168+
"x = 3",
169+
);
170+
171+
const stagedCells2 = store.get(stagedAICellsAtom);
172+
expect(stagedCells2.has(cellId1)).toBe(true);
173+
expect(stagedCells2.get(cellId1)).toEqual({
174+
type: "update_cell",
175+
previousCode: oldCode, // Should keep the original code
176+
});
152177
});
153178

154179
it("should throw error when cell ID doesn't exist", async () => {

frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,45 @@ describe("RunStaleCellsTool", () => {
332332
expect(result.cellsToOutput?.[cellId1]?.consoleOutput).toContain("debug");
333333
});
334334

335+
it("should handle cell output with object output", async () => {
336+
const notebook = MockNotebook.notebookState({
337+
cellData: {
338+
[cellId1]: { code: "x = {'a': 1, 'b': 2}", edited: true },
339+
},
340+
});
341+
store.set(notebookAtom, notebook);
342+
343+
// Mock runCells
344+
vi.mocked(runCells).mockImplementation(async () => {
345+
const updatedNotebook = store.get(notebookAtom);
346+
updatedNotebook.cellRuntime[cellId1] = {
347+
...updatedNotebook.cellRuntime[cellId1],
348+
status: "idle",
349+
};
350+
store.set(notebookAtom, updatedNotebook);
351+
});
352+
353+
// Mock getCellContextData to return object output
354+
vi.mocked(getCellContextData).mockReturnValue({
355+
cellOutput: {
356+
outputType: "text",
357+
processedContent: null,
358+
imageUrl: null,
359+
output: { data: JSON.stringify({ a: 1, b: 2 }) },
360+
},
361+
consoleOutputs: null,
362+
cellName: "cell1",
363+
} as never);
364+
365+
const result = await tool.handler({}, toolContext as never);
366+
367+
expect(result.status).toBe("success");
368+
expect(result.cellsToOutput).toBeDefined();
369+
expect(result.cellsToOutput?.[cellId1]?.cellOutput).toEqual(
370+
'Output:\n{"a":1,"b":2}',
371+
);
372+
});
373+
335374
it("should return success when all stale cells have no output", async () => {
336375
const notebook = MockNotebook.notebookState({
337376
cellData: {

frontend/src/core/ai/tools/edit-notebook-tool.ts

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
import { CellId } from "@/core/cells/ids";
1212
import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils";
1313
import type { CellColumnId } from "@/utils/id-tree";
14+
import { stagedAICellsAtom } from "../staged-cells";
1415
import {
1516
type AiTool,
1617
type ToolDescription,
@@ -23,9 +24,9 @@ import type { CopilotMode } from "./registry";
2324

2425
const description: ToolDescription = {
2526
baseDescription:
26-
"Perform editing operations on the current notebook. Call this tool multiple times to perform multiple edits.",
27+
"Perform editing operations on the current notebook. You should prefer to create new cells unless you need to edit existing cells. Call this tool multiple times to perform multiple edits. Separate code into logical individual cells to take advantage of the notebook's reactive execution model.",
2728
prerequisites: [
28-
"Find out the cellIds and columnIds first (call lightweight cell map tool)",
29+
"If you are updating existing cells, you need the cellIds or columnIds. If they are not known, call the lightweight_cell_map_tool to find out.",
2930
],
3031
additionalInfo: `
3132
Args:
@@ -108,10 +109,19 @@ export class EditNotebookTool
108109

109110
scrollAndHighlightCell(cellId);
110111

112+
// If previous code exists, we don't want to replace it, it means there is a new edit on top of the previous edit
113+
// Keep the original code
114+
const stagedCell = store.get(stagedAICellsAtom).get(cellId);
111115
const currentCellCode = editorView.state.doc.toString();
116+
const previousCode =
117+
stagedCell?.type === "update_cell" ||
118+
stagedCell?.type === "delete_cell"
119+
? stagedCell.previousCode
120+
: currentCellCode;
121+
112122
addStagedCell({
113123
cellId,
114-
edit: { type: "update_cell", previousCode: currentCellCode },
124+
edit: { type: "update_cell", previousCode: previousCode },
115125
});
116126

117127
updateEditorCodeFromPython(editorView, code);

frontend/src/core/ai/tools/run-cells-tool.ts

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,24 @@ export class RunStaleCellsTool
165165
status: "success",
166166
cellsToOutput: Object.fromEntries(cellsToOutput),
167167
message: resultMessage === "" ? undefined : resultMessage,
168+
next_steps: [
169+
"Review the output of the cells, if you need to make any changes, you can run this tool again, the CellId is the key of the result object if you want to edit the cell.",
170+
],
168171
};
169172
};
170173

171174
private formatOutputString(cellOutput: BaseOutput): string {
172175
let outputString = "";
173176
const { outputType, processedContent, imageUrl, output } = cellOutput;
174-
if (outputType === "text" && processedContent) {
175-
outputString += `Output:\n${processedContent}`;
177+
if (outputType === "text") {
178+
outputString += "Output:\n";
179+
if (processedContent) {
180+
outputString += processedContent;
181+
} else if (typeof output.data === "string") {
182+
outputString += output.data;
183+
} else {
184+
outputString += JSON.stringify(output.data);
185+
}
176186
} else if (outputType === "media") {
177187
outputString += `Media Output: Contains ${output.mimetype} content`;
178188
if (imageUrl) {

marimo/_ai/_tools/tools/datasource.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from marimo._ai._tools.utils.exceptions import ToolExecutionError
1212
from marimo._data.models import DataTable
1313
from marimo._server.sessions import Session
14+
from marimo._sql.engines.duckdb import INTERNAL_DUCKDB_ENGINE
1415
from marimo._types.ids import SessionId
1516
from marimo._utils.fuzzy_match import compile_regex, is_fuzzy_match
1617

@@ -29,6 +30,7 @@ class TableDetails:
2930
database: str
3031
schema: str
3132
table: DataTable
33+
sample_query: str
3234

3335

3436
@dataclass
@@ -93,38 +95,81 @@ def _get_tables(
9395

9496
for connection in data_connectors.connections:
9597
for database in connection.databases:
98+
default_database = connection.default_database == database.name
9699
for schema in database.schemas:
100+
default_schema = connection.default_schema == schema.name
97101
# If query is None, match all schemas
98102
# If matching, add all tables to the list
99103
if query is None or is_fuzzy_match(
100104
query, schema.name, compiled_pattern, is_regex
101105
):
102106
for table in schema.tables:
107+
sample_query = self._form_sample_query(
108+
database=database.name,
109+
schema=schema.name,
110+
table=table.name,
111+
default_database=default_database,
112+
default_schema=default_schema,
113+
engine=connection.name,
114+
)
103115
tables.append(
104116
TableDetails(
105117
connection=connection.name,
106118
database=database.name,
107119
schema=schema.name,
108120
table=table,
121+
sample_query=sample_query,
109122
)
110123
)
111124
continue
112125
for table in schema.tables:
113126
if is_fuzzy_match(
114127
query, table.name, compiled_pattern, is_regex
115128
):
129+
sample_query = self._form_sample_query(
130+
database=database.name,
131+
schema=schema.name,
132+
table=table.name,
133+
default_database=default_database,
134+
default_schema=default_schema,
135+
engine=connection.name,
136+
)
116137
tables.append(
117138
TableDetails(
118139
connection=connection.name,
119140
database=database.name,
120141
schema=schema.name,
121142
table=table,
143+
sample_query=sample_query,
122144
)
123145
)
124146

125147
return GetDatabaseTablesOutput(
126148
tables=tables,
127149
next_steps=[
128-
'Example of an SQL query: _df = mo.sql(f"""SELECT * FROM database.schema.name LIMIT 100""")',
150+
"Use the sample query as a guideline to write your own SQL query."
129151
],
130152
)
153+
154+
def _form_sample_query(
155+
self,
156+
*,
157+
database: str,
158+
schema: str,
159+
table: str,
160+
default_database: bool,
161+
default_schema: bool,
162+
engine: str,
163+
) -> str:
164+
sample_query = f"SELECT * FROM {database}.{schema}.{table} LIMIT 100"
165+
if default_database:
166+
sample_query = f"SELECT * FROM {schema}.{table} LIMIT 100"
167+
if default_schema:
168+
sample_query = f"SELECT * FROM {table} LIMIT 100"
169+
if engine != INTERNAL_DUCKDB_ENGINE:
170+
wrapped_query = (
171+
f'df = mo.sql(f"""{sample_query}""", engine={engine})'
172+
)
173+
else:
174+
wrapped_query = f'df = mo.sql(f"""{sample_query}""")'
175+
return wrapped_query

0 commit comments

Comments
 (0)