diff --git a/frontend/src/core/cells/__tests__/add-missing-import.test.ts b/frontend/src/core/cells/__tests__/add-missing-import.test.ts index 80f330e6b0e..d85028ce48a 100644 --- a/frontend/src/core/cells/__tests__/add-missing-import.test.ts +++ b/frontend/src/core/cells/__tests__/add-missing-import.test.ts @@ -1,18 +1,35 @@ /* Copyright 2024 Marimo. All rights reserved. */ import { createStore } from "jotai"; -import { describe, expect, it, vi } from "vitest"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { MockNotebook } from "@/__mocks__/notebook"; +import { MockRequestClient } from "@/__mocks__/requests"; +import { store } from "@/core/state/jotai"; import { variablesAtom } from "@/core/variables/state"; import type { Variables } from "@/core/variables/types"; -import { CollapsibleTree, MultiColumn } from "@/utils/id-tree"; -import { maybeAddMissingImport } from "../add-missing-import"; -import { type NotebookState, notebookAtom } from "../cells"; +import { + maybeAddAltairImport, + maybeAddMarimoImport, + maybeAddMissingImport, +} from "../add-missing-import"; +import { notebookAtom } from "../cells"; import type { CellId } from "../ids"; -import type { CellData } from "../types"; -const CELL_IDS = new MultiColumn([CollapsibleTree.from(["1" as CellId])]); +// Mock the getRequestClient function +const mockRequestClient = MockRequestClient.create(); +vi.mock("@/core/network/requests", () => ({ + getRequestClient: () => mockRequestClient, +})); + +const Cell1 = "1" as CellId; +const Cell2 = "2" as CellId; describe("maybeAddMissingImport", () => { + beforeEach(() => { + store.set(variablesAtom, {} as Variables); + store.set(notebookAtom, MockNotebook.notebookState({ cellData: {} })); + }); + it("should not add an import if the variable is already in the variables state", () => { const appStore = createStore(); appStore.set(variablesAtom, { mo: {} } as Variables); @@ -36,14 +53,14 @@ describe("maybeAddMissingImport", () => { (code) => { const appStore = createStore(); appStore.set(variablesAtom, {} as Variables); - appStore.set(notebookAtom, { - cellData: { - ["1" as CellId]: { - code: code, - } as CellData, - }, - cellIds: CELL_IDS, - } as NotebookState); + appStore.set( + notebookAtom, + MockNotebook.notebookState({ + cellData: { + [Cell1]: { code: code }, + }, + }), + ); const onAddImport = vi.fn(); maybeAddMissingImport({ moduleName: "marimo", @@ -58,14 +75,14 @@ describe("maybeAddMissingImport", () => { it("should add an import if the variable is not in the variables state and the import statement does not exist in the notebook", () => { const appStore = createStore(); appStore.set(variablesAtom, {} as Variables); - appStore.set(notebookAtom, { - cellData: { - ["1" as CellId]: { - code: "mo.md('hello')", - } as CellData, - }, - cellIds: CELL_IDS, - } as NotebookState); + appStore.set( + notebookAtom, + MockNotebook.notebookState({ + cellData: { + [Cell1]: { code: "mo.md('hello')" }, + }, + }), + ); const onAddImport = vi.fn(); maybeAddMissingImport({ moduleName: "marimo", @@ -75,4 +92,32 @@ describe("maybeAddMissingImport", () => { }); expect(onAddImport).toHaveBeenCalledWith("import marimo as mo"); }); + + it("should not create a new cell if import already exists due to skipIfCodeExists", () => { + const addImports = [maybeAddMarimoImport, maybeAddAltairImport]; + for (const addImport of addImports) { + store.set(variablesAtom, {} as Variables); + store.set( + notebookAtom, + MockNotebook.notebookState({ + cellData: { + [Cell1]: { code: "import marimo as mo" }, + [Cell2]: { code: "import altair as alt" }, + }, + }), + ); + + const createNewCell = vi.fn(); + const result = addImport({ + autoInstantiate: false, + createNewCell, + fromCellId: Cell1, + before: false, + }); + + // Should not create a new cell since the import already exists + expect(createNewCell).not.toHaveBeenCalled(); + expect(result).toBeNull(); + } + }); }); diff --git a/frontend/src/core/cells/add-missing-import.ts b/frontend/src/core/cells/add-missing-import.ts index e1a2e2d5955..dbaa638523a 100644 --- a/frontend/src/core/cells/add-missing-import.ts +++ b/frontend/src/core/cells/add-missing-import.ts @@ -80,6 +80,7 @@ export function maybeAddMarimoImport({ code: importStatement, lastCodeRun: autoInstantiate ? importStatement : undefined, newCellId: newCellId, + skipIfCodeExists: true, autoFocus: false, }); if (autoInstantiate) { @@ -101,19 +102,21 @@ export function maybeAddAltairImport({ autoInstantiate: boolean; createNewCell: CellActions["createNewCell"]; fromCellId?: CellId | null; -}): boolean { +}): CellId | null { const client = getRequestClient(); - return maybeAddMissingImport({ + let newCellId: CellId | null = null; + const added = maybeAddMissingImport({ moduleName: "altair", variableName: "alt", onAddImport: (importStatement) => { - const newCellId = CellId.create(); + newCellId = CellId.create(); createNewCell({ cellId: fromCellId ?? "__end__", before: false, code: importStatement, lastCodeRun: autoInstantiate ? importStatement : undefined, newCellId: newCellId, + skipIfCodeExists: true, autoFocus: false, }); if (autoInstantiate) { @@ -124,4 +127,5 @@ export function maybeAddAltairImport({ } }, }); + return added ? newCellId : null; }