Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import { notebookAtom } from "@/core/cells/cells";
import type { CellId } from "@/core/cells/ids";
import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils";
import { OverridingHotkeyProvider } from "@/core/hotkeys/hotkeys";
import type { CellColumnId } from "@/utils/id-tree";
import { MultiColumn } from "@/utils/id-tree";
import { cellConfigExtension } from "../../../codemirror/config/extension";
import { adaptiveLanguageConfiguration } from "../../../codemirror/language/extension";
Expand Down Expand Up @@ -250,7 +249,7 @@ describe("EditNotebookTool", () => {
{
edit: {
type: "add_cell",
position: "__end__",
position: "end",
code: newCode,
},
},
Expand Down Expand Up @@ -334,15 +333,14 @@ describe("EditNotebookTool", () => {
});
// Create multi-column layout
notebook.cellIds = MultiColumn.from([[cellId1], [cellId2]]);
const columnId = notebook.cellIds.getColumns()[1].id;
store.set(notebookAtom, notebook);

const newCode = "y = 2";
const result = await tool.handler(
{
edit: {
type: "add_cell",
position: { type: "__end__", columnId },
position: { type: "end", columnIndex: 1 },
code: newCode,
},
},
Expand Down Expand Up @@ -378,7 +376,7 @@ describe("EditNotebookTool", () => {
).rejects.toThrow("Cell not found");
});

it("should throw error when column ID doesn't exist", async () => {
it("should throw error when column index is out of range", async () => {
const notebook = MockNotebook.notebookState({
cellData: {
[cellId1]: { code: "x = 1" },
Expand All @@ -392,15 +390,15 @@ describe("EditNotebookTool", () => {
edit: {
type: "add_cell",
position: {
type: "__end__",
columnId: "nonexistent" as CellColumnId,
type: "end",
columnIndex: -1,
},
code: "y = 2",
},
},
toolContext as never,
),
).rejects.toThrow("Column not found");
).rejects.toThrow("Column index is out of range");
});
});

Expand Down Expand Up @@ -542,7 +540,7 @@ describe("EditNotebookTool", () => {
{
edit: {
type: "add_cell",
position: "__end__",
position: "end",
code: "y = 2",
},
},
Expand Down
41 changes: 22 additions & 19 deletions frontend/src/core/ai/tools/edit-notebook-tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ const description: ToolDescription = {
baseDescription:
"Perform editing operations on the current notebook. You should prefer to create new cells unless you need to edit existing cells. Call this tool multiple times to perform multiple edits. Separate code into logical individual cells to take advantage of the notebook's reactive execution model.",
prerequisites: [
"If you are updating existing cells, you need the cellIds or columnIds. If they are not known, call the lightweight_cell_map_tool to find out.",
"If you are updating existing cells, you need the cellIds. If they are not known, call the lightweight_cell_map_tool to find out.",
],
additionalInfo: `
Args:
edit (object): The editing operation to perform. Must be one of:
- update_cell: Update the code of an existing cell, pass CellId and the new code.
- add_cell: Add a new cell to the notebook. The position of the new cell is specified by the position argument.
Pass "__end__" to add the new cell at the end of the notebook.
Pass "end" to add the new cell at the end of the notebook.
Pass { cellId: cellId, before: true } to add the new cell before the specified cell. And before: false if after the specified cell.
Pass { type: "__end__", columnId: columnId } to add the new cell at the end of the specified column.
Pass { type: "end", columnIndex: number } to add the new cell at the end of a specified column index. The column index is 0-based.
- delete_cell: Delete an existing cell, pass CellId. For deleting cells, the user needs to accept the deletion to actually delete the cell, so you may still see the cell in the notebook on subsequent edits which is fine.

For adding code, use the following guidelines:
Expand All @@ -48,8 +48,8 @@ const description: ToolDescription = {

type CellPosition =
| { cellId: CellId; before: boolean }
| { type: "__end__"; columnId: CellColumnId }
| "__end__";
| { type: "end"; columnIndex: number }
| "end";

const editNotebookSchema = z.object({
edit: z.discriminatedUnion("type", [
Expand All @@ -66,10 +66,10 @@ const editNotebookSchema = z.object({
before: z.boolean(),
}),
z.object({
type: z.literal("__end__"),
columnId: z.string() as unknown as z.ZodType<CellColumnId>,
type: z.literal("end"),
columnIndex: z.number(),
}),
z.literal("__end__"),
z.literal("end"),
]) satisfies z.ZodType<CellPosition>,
code: z.string(),
}),
Expand Down Expand Up @@ -142,9 +142,9 @@ export class EditNotebookTool
this.validateCellIdExists(position.cellId, notebook);
notebookPosition = position.cellId;
before = position.before;
} else if ("columnId" in position) {
this.validateColumnIdExists(position.columnId, notebook);
notebookPosition = { type: "__end__", columnId: position.columnId };
} else if ("columnIndex" in position) {
const columnId = this.getColumnId(position.columnIndex, notebook);
notebookPosition = { type: "__end__", columnId };
}
}

Expand Down Expand Up @@ -202,19 +202,22 @@ export class EditNotebookTool
}
}

private validateColumnIdExists(
columnId: CellColumnId,
private getColumnId(
columnIndex: number,
notebook: NotebookState,
) {
): CellColumnId {
const cellIds = notebook.cellIds;
if (!cellIds.getColumns().some((column) => column.id === columnId)) {
const columns = cellIds.getColumns();

if (columnIndex < 0 || columnIndex >= columns.length) {
throw new ToolExecutionError(
"Column not found",
"COLUMN_NOT_FOUND",
false,
"Check which columns exist in the notebook",
"Column index is out of range",
"COLUMN_INDEX_OUT_OF_RANGE",
true,
"Choose a column index between 0 and the number of columns in the notebook (0-based)",
);
}
return columns[columnIndex].id;
}

private getCellEditorView(
Expand Down
Loading