From 826d1842aa43ca77a593004747301b899280766d Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 2 Oct 2025 12:05:16 +0800 Subject: [PATCH 1/8] wip --- .../editor/ai/ai-completion-editor.tsx | 19 ++++- .../editor/cell/code/cell-editor.tsx | 1 + .../src/components/editor/errors/auto-fix.tsx | 75 +++++++++++-------- frontend/src/core/ai/state.ts | 9 ++- frontend/src/core/cells/cells.ts | 2 +- frontend/src/core/errors/errors.ts | 42 +++++++++-- 6 files changed, 107 insertions(+), 41 deletions(-) diff --git a/frontend/src/components/editor/ai/ai-completion-editor.tsx b/frontend/src/components/editor/ai/ai-completion-editor.tsx index 35569b597f4..24790e284db 100644 --- a/frontend/src/components/editor/ai/ai-completion-editor.tsx +++ b/frontend/src/components/editor/ai/ai-completion-editor.tsx @@ -45,6 +45,7 @@ interface Props { declineChange: () => void; acceptChange: (rightHandCode: string) => void; enabled: boolean; + initialTrigger?: boolean; /** * Children shown when there is no completion */ @@ -67,9 +68,13 @@ export const AiCompletionEditor: React.FC = ({ declineChange, acceptChange, enabled, + initialTrigger, children, }) => { - const [completionBody, setCompletionBody] = useState({}); + const [hasTriggered, setHasTriggered] = useState(false); + const [completionBody, setCompletionBody] = useState( + initialPrompt ? getAICompletionBody({ input: initialPrompt }) : {}, + ); const [includeOtherCells, setIncludeOtherCells] = useAtom( includeOtherCellsAtom, @@ -140,6 +145,18 @@ export const AiCompletionEditor: React.FC = ({ } }, [enabled, initialPrompt, setInput]); + // TODO: Does not work properly + if (!hasTriggered && initialTrigger) { + setHasTriggered(true); + // Use requestAnimationFrame for better timing + requestAnimationFrame(() => { + if (inputRef.current?.view) { + storePrompt(inputRef.current.view); + } + handleSubmit(); + }); + } + const { theme } = useTheme(); const handleAcceptCompletion = () => { diff --git a/frontend/src/components/editor/cell/code/cell-editor.tsx b/frontend/src/components/editor/cell/code/cell-editor.tsx index c6ef30bbe0d..ec6dd507a46 100644 --- a/frontend/src/components/editor/cell/code/cell-editor.tsx +++ b/frontend/src/components/editor/cell/code/cell-editor.tsx @@ -407,6 +407,7 @@ const CellEditorInternal = ({ { diff --git a/frontend/src/components/editor/errors/auto-fix.tsx b/frontend/src/components/editor/errors/auto-fix.tsx index c5ed56e0789..fe54cdd1218 100644 --- a/frontend/src/components/editor/errors/auto-fix.tsx +++ b/frontend/src/components/editor/errors/auto-fix.tsx @@ -1,7 +1,7 @@ /* Copyright 2024 Marimo. All rights reserved. */ import { useAtomValue, useSetAtom } from "jotai"; -import { WrenchIcon } from "lucide-react"; +import { WrenchIcon, ZapIcon } from "lucide-react"; import { Button } from "@/components/ui/button"; import { Tooltip } from "@/components/ui/tooltip"; import { aiCompletionCellAtom } from "@/core/ai/state"; @@ -37,35 +37,50 @@ export const AutoFixButton = ({ // multiple fixes. const firstFix = autoFixes[0]; + const handleFix = (aiInstantFix = false) => { + const editorView = + store.get(notebookAtom).cellHandles[cellId].current?.editorView; + firstFix.onFix({ + addCodeBelow: (code: string) => { + createNewCell({ + cellId: cellId, + autoFocus: false, + before: false, + code: code, + }); + }, + editor: editorView, + cellId: cellId, + aiFix: { + setAiCompletionCell, + instantFix: aiInstantFix, + }, + }); + // Focus the editor + editorView?.focus(); + }; + return ( - - - +
+ + + + + {firstFix.fixType === "ai" && ( + + + + )} +
); }; diff --git a/frontend/src/core/ai/state.ts b/frontend/src/core/ai/state.ts index 008e2314a60..a23cf83c888 100644 --- a/frontend/src/core/ai/state.ts +++ b/frontend/src/core/ai/state.ts @@ -13,10 +13,13 @@ const KEY = "marimo:ai:chatState:v5"; export type ChatId = TypedString<"ChatId">; -export const aiCompletionCellAtom = atom<{ +export interface AiCompletionCell { cellId: CellId; initialPrompt?: string; -} | null>(null); + triggerImmediately?: boolean; +} + +export const aiCompletionCellAtom = atom(null); const INCLUDE_OTHER_CELLS_KEY = "marimo:ai:includeOtherCells"; export const includeOtherCellsAtom = atomWithStorage( @@ -86,7 +89,7 @@ export const activeChatAtom = atom( } return state.chats.get(state.activeChatId); }, - (get, set, chatId: ChatId | null) => { + (_get, set, chatId: ChatId | null) => { set(chatStateAtom, (prev) => ({ ...prev, activeChatId: chatId, diff --git a/frontend/src/core/cells/cells.ts b/frontend/src/core/cells/cells.ts index 4b7ebc21719..5fd8e220ba1 100644 --- a/frontend/src/core/cells/cells.ts +++ b/frontend/src/core/cells/cells.ts @@ -1615,7 +1615,7 @@ export const columnIdsAtom = atom((get) => get(notebookAtom).cellIds.getColumnIds(), ); -const cellDataAtom = atomFamily((cellId: CellId) => +export const cellDataAtom = atomFamily((cellId: CellId) => atom((get) => get(notebookAtom).cellData[cellId]), ); const cellRuntimeAtom = atomFamily((cellId: CellId) => diff --git a/frontend/src/core/errors/errors.ts b/frontend/src/core/errors/errors.ts index 0cc30d1da88..e1a4ba21166 100644 --- a/frontend/src/core/errors/errors.ts +++ b/frontend/src/core/errors/errors.ts @@ -1,21 +1,29 @@ /* Copyright 2024 Marimo. All rights reserved. */ import type { EditorView } from "@codemirror/view"; import { invariant } from "@/utils/invariant"; +import type { AiCompletionCell } from "../ai/state"; +import { cellDataAtom } from "../cells/cells"; import type { CellId } from "../cells/ids"; +import { LanguageAdapters } from "../codemirror/language/LanguageAdapters"; +import { dataSourceConnectionsAtom } from "../datasets/data-source-connections"; import type { MarimoError } from "../kernel/messages"; +import { store } from "../state/jotai"; import { wrapInFunction } from "./utils"; +interface AIFix { + setAiCompletionCell: (opts: AiCompletionCell) => void; + instantFix: boolean; +} + export interface AutoFix { title: string; description: string; + fixType: "manual" | "ai"; onFix: (ctx: { addCodeBelow: (code: string) => void; editor: EditorView | undefined; cellId: CellId; - setAiCompletionCell?: (cell: { - cellId: CellId; - initialPrompt?: string; - }) => void; + aiFix?: AIFix; }) => Promise; } @@ -31,6 +39,7 @@ export function getAutoFixes( title: "Fix: Wrap in a function", description: "Make this cell's variables local by wrapping the cell in a function.", + fixType: "manual", onFix: async (ctx) => { invariant(ctx.editor, "Editor is null"); const code = wrapInFunction(ctx.editor.state.doc.toString()); @@ -59,6 +68,7 @@ export function getAutoFixes( { title: `Fix: Add '${cellCode}'`, description: "Add a new cell for the missing import", + fixType: "manual", onFix: async (ctx) => { ctx.addCodeBelow(cellCode); }, @@ -75,10 +85,17 @@ export function getAutoFixes( { title: "Fix with AI", description: "Fix the SQL statement", + fixType: "ai", onFix: async (ctx) => { - ctx.setAiCompletionCell?.({ + const datasourceContext = getDatasourceContext(ctx.cellId); + let initialPrompt = `Fix the SQL statement: ${error.msg}.`; + if (datasourceContext) { + initialPrompt += `\nUse the following tables and schema to form your query: ${datasourceContext}`; + } + ctx.aiFix?.setAiCompletionCell({ cellId: ctx.cellId, - initialPrompt: `Fix the SQL statement: ${error.msg}`, + initialPrompt: initialPrompt, + triggerImmediately: ctx.aiFix.instantFix, }); }, }, @@ -122,3 +139,16 @@ const IMPORT_MAPPING: Record = { re: "re", sys: "sys", }; + +function getDatasourceContext(cellId: CellId): string | null { + const cellData = store.get(cellDataAtom(cellId)); + const code = cellData?.code; + const [_sqlStatement, _, metadata] = LanguageAdapters.sql.transformIn(code); + const datasourceSchema = store + .get(dataSourceConnectionsAtom) + .connectionsMap.get(metadata.engine); + if (datasourceSchema) { + return `@datasource://${datasourceSchema.name}`; + } + return null; +} From d124efe08c0b59b2d55a2509520a9f8bca34c87b Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 2 Oct 2025 12:09:00 +0800 Subject: [PATCH 2/8] move function --- .../core/ai/context/providers/datasource.ts | 22 ++++++++++++++++++- frontend/src/core/errors/errors.ts | 18 +-------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/frontend/src/core/ai/context/providers/datasource.ts b/frontend/src/core/ai/context/providers/datasource.ts index e37ce323286..ab32848657d 100644 --- a/frontend/src/core/ai/context/providers/datasource.ts +++ b/frontend/src/core/ai/context/providers/datasource.ts @@ -3,10 +3,14 @@ import type { Completion } from "@codemirror/autocomplete"; import { createRoot } from "react-dom/client"; import { dbDisplayName } from "@/components/databases/display"; +import { cellDataAtom } from "@/core/cells/cells"; +import type { CellId } from "@/core/cells/ids"; +import { LanguageAdapters } from "@/core/codemirror/language/LanguageAdapters"; import { renderDatasourceInfo } from "@/core/codemirror/language/languages/sql/renderers"; import { type ConnectionsMap, type DatasetTablesMap, + dataSourceConnectionsAtom, getTableType, } from "@/core/datasets/data-source-connections"; import { @@ -14,6 +18,7 @@ import { INTERNAL_SQL_ENGINES, } from "@/core/datasets/engines"; import type { DataSourceConnection, DataTable } from "@/core/kernel/messages"; +import { store } from "@/core/state/jotai"; import type { AIContextItem } from "../registry"; import { AIContextProvider } from "../registry"; import { contextToXml } from "../utils"; @@ -37,10 +42,12 @@ export interface DatasourceContextItem extends AIContextItem { }; } +const CONTEXT_TYPE = "datasource"; + export class DatasourceContextProvider extends AIContextProvider { readonly title = "Datasource"; readonly mentionPrefix = "@"; - readonly contextType = "datasource"; + readonly contextType = CONTEXT_TYPE; private connectionsMap: ConnectionsMap; private dataframes: DataTable[]; @@ -140,3 +147,16 @@ export class DatasourceContextProvider extends AIContextProvider = { re: "re", sys: "sys", }; - -function getDatasourceContext(cellId: CellId): string | null { - const cellData = store.get(cellDataAtom(cellId)); - const code = cellData?.code; - const [_sqlStatement, _, metadata] = LanguageAdapters.sql.transformIn(code); - const datasourceSchema = store - .get(dataSourceConnectionsAtom) - .connectionsMap.get(metadata.engine); - if (datasourceSchema) { - return `@datasource://${datasourceSchema.name}`; - } - return null; -} From 1629516a322fead80e7b2eb759fc0c14d933a916 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 2 Oct 2025 13:58:50 +0800 Subject: [PATCH 3/8] fixes --- .../ai/context/providers/__tests__/datasource.test.ts | 10 +++++++++- frontend/src/core/ai/context/providers/datasource.ts | 6 +++++- frontend/src/core/errors/__tests__/errors.test.ts | 1 + 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/frontend/src/core/ai/context/providers/__tests__/datasource.test.ts b/frontend/src/core/ai/context/providers/__tests__/datasource.test.ts index 1521922438e..a3f52fc14db 100644 --- a/frontend/src/core/ai/context/providers/__tests__/datasource.test.ts +++ b/frontend/src/core/ai/context/providers/__tests__/datasource.test.ts @@ -2,6 +2,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { beforeEach, describe, expect, it } from "vitest"; +import type { CellId } from "@/core/cells/ids"; import type { ConnectionsMap, DatasetTablesMap, @@ -9,7 +10,7 @@ import type { import { DUCKDB_ENGINE } from "@/core/datasets/engines"; import type { DataSourceConnection, DataTable } from "@/core/kernel/messages"; import { Boosts, Sections } from "../common"; -import { DatasourceContextProvider } from "../datasource"; +import { DatasourceContextProvider, getDatasourceContext } from "../datasource"; // Mock data for testing const createMockDataSourceConnection = ( @@ -630,3 +631,10 @@ describe("DatasourceContextProvider", () => { }); }); }); + +describe("getDatasourceContext", () => { + it("should return null if no cell ID is found", () => { + const context = getDatasourceContext("1" as CellId); + expect(context).toBeNull(); + }); +}); diff --git a/frontend/src/core/ai/context/providers/datasource.ts b/frontend/src/core/ai/context/providers/datasource.ts index ab32848657d..04d718189da 100644 --- a/frontend/src/core/ai/context/providers/datasource.ts +++ b/frontend/src/core/ai/context/providers/datasource.ts @@ -151,12 +151,16 @@ export class DatasourceContextProvider extends AIContextProvider { expect(fixes).toHaveLength(1); expect(fixes[0].title).toBe("Fix with AI"); + // No fixes without AI expect(getAutoFixes(error, { aiEnabled: false })).toHaveLength(0); }); From 4863b8574383e407a8ad6e5bb967bfd649768872 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 3 Oct 2025 01:01:36 +0800 Subject: [PATCH 4/8] fix useEffect, and have switch --- .../editor/ai/ai-completion-editor.tsx | 37 +++++++++--------- .../editor/cell/code/cell-editor.tsx | 2 +- .../components/editor/errors/auto-fix-atom.ts | 18 +++++++++ .../src/components/editor/errors/auto-fix.tsx | 38 ++++++++++++++----- 4 files changed, 64 insertions(+), 31 deletions(-) create mode 100644 frontend/src/components/editor/errors/auto-fix-atom.ts diff --git a/frontend/src/components/editor/ai/ai-completion-editor.tsx b/frontend/src/components/editor/ai/ai-completion-editor.tsx index 24790e284db..b283e0dd3de 100644 --- a/frontend/src/components/editor/ai/ai-completion-editor.tsx +++ b/frontend/src/components/editor/ai/ai-completion-editor.tsx @@ -22,6 +22,7 @@ import { getCodes } from "@/core/codemirror/copilot/getCodes"; import type { LanguageAdapterType } from "@/core/codemirror/language/types"; import { selectAllText } from "@/core/codemirror/utils"; import { useRuntimeManager } from "@/core/runtime/config"; +import { useEvent } from "@/hooks/useEvent"; import { useTheme } from "@/theme/useTheme"; import { cn } from "@/utils/cn"; import { prettyError } from "@/utils/errors"; @@ -45,7 +46,7 @@ interface Props { declineChange: () => void; acceptChange: (rightHandCode: string) => void; enabled: boolean; - initialTrigger?: boolean; + triggerImmediately?: boolean; /** * Children shown when there is no completion */ @@ -68,13 +69,10 @@ export const AiCompletionEditor: React.FC = ({ declineChange, acceptChange, enabled, - initialTrigger, + triggerImmediately, children, }) => { - const [hasTriggered, setHasTriggered] = useState(false); - const [completionBody, setCompletionBody] = useState( - initialPrompt ? getAICompletionBody({ input: initialPrompt }) : {}, - ); + const [completionBody, setCompletionBody] = useState({}); const [includeOtherCells, setIncludeOtherCells] = useAtom( includeOtherCellsAtom, @@ -99,7 +97,11 @@ export const AiCompletionEditor: React.FC = ({ // Throttle the messages and data updates to 100ms experimental_throttle: 100, body: { - ...completionBody, + ...(Object.keys(completionBody).length > 0 + ? completionBody + : initialPrompt + ? getAICompletionBody({ input: initialPrompt }) + : {}), includeOtherCode: includeOtherCells ? getCodes(currentCode) : "", code: currentCode, language: currentLanguageAdapter, @@ -119,6 +121,12 @@ export const AiCompletionEditor: React.FC = ({ const inputRef = React.useRef(null); const completion = untrimmedCompletion.trimEnd(); + const initialSubmit = useEvent(() => { + if (triggerImmediately && !isLoading && initialPrompt) { + handleSubmit(); + } + }); + // Focus the input useEffect(() => { if (enabled) { @@ -127,6 +135,7 @@ export const AiCompletionEditor: React.FC = ({ const input = inputRef.current; if (input?.view) { input.view.focus(); + initialSubmit(); return true; } return false; @@ -136,7 +145,7 @@ export const AiCompletionEditor: React.FC = ({ selectAllText(inputRef.current?.view); } - }, [enabled]); + }, [enabled, initialSubmit]); // Reset the input when the prompt changes useEffect(() => { @@ -145,18 +154,6 @@ export const AiCompletionEditor: React.FC = ({ } }, [enabled, initialPrompt, setInput]); - // TODO: Does not work properly - if (!hasTriggered && initialTrigger) { - setHasTriggered(true); - // Use requestAnimationFrame for better timing - requestAnimationFrame(() => { - if (inputRef.current?.view) { - storePrompt(inputRef.current.view); - } - handleSubmit(); - }); - } - const { theme } = useTheme(); const handleAcceptCompletion = () => { diff --git a/frontend/src/components/editor/cell/code/cell-editor.tsx b/frontend/src/components/editor/cell/code/cell-editor.tsx index ec6dd507a46..a422c97be96 100644 --- a/frontend/src/components/editor/cell/code/cell-editor.tsx +++ b/frontend/src/components/editor/cell/code/cell-editor.tsx @@ -407,7 +407,7 @@ const CellEditorInternal = ({ { diff --git a/frontend/src/components/editor/errors/auto-fix-atom.ts b/frontend/src/components/editor/errors/auto-fix-atom.ts new file mode 100644 index 00000000000..25fe2a12508 --- /dev/null +++ b/frontend/src/components/editor/errors/auto-fix-atom.ts @@ -0,0 +1,18 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { useAtom } from "jotai"; +import { atomWithStorage } from "jotai/utils"; +import { store } from "@/core/state/jotai"; + +const BASE_KEY = "marimo:instant-ai-fix"; + +const instantAIFixAtom = atomWithStorage(BASE_KEY, false); + +export function useInstantAIFix() { + const [instantAIFix, setInstantAIFix] = useAtom(instantAIFixAtom); + return { instantAIFix, setInstantAIFix }; +} + +export function getInstantAIFix() { + return store.get(instantAIFixAtom); +} diff --git a/frontend/src/components/editor/errors/auto-fix.tsx b/frontend/src/components/editor/errors/auto-fix.tsx index fe54cdd1218..8fa9c2bd052 100644 --- a/frontend/src/components/editor/errors/auto-fix.tsx +++ b/frontend/src/components/editor/errors/auto-fix.tsx @@ -1,8 +1,9 @@ /* Copyright 2024 Marimo. All rights reserved. */ import { useAtomValue, useSetAtom } from "jotai"; -import { WrenchIcon, ZapIcon } from "lucide-react"; +import { WrenchIcon, ZapIcon, ZapOffIcon } from "lucide-react"; import { Button } from "@/components/ui/button"; +import { Switch } from "@/components/ui/switch"; import { Tooltip } from "@/components/ui/tooltip"; import { aiCompletionCellAtom } from "@/core/ai/state"; import { notebookAtom, useCellActions } from "@/core/cells/cells"; @@ -12,6 +13,7 @@ import { getAutoFixes } from "@/core/errors/errors"; import type { MarimoError } from "@/core/kernel/messages"; import { store } from "@/core/state/jotai"; import { cn } from "@/utils/cn"; +import { useInstantAIFix } from "./auto-fix-atom"; export const AutoFixButton = ({ errors, @@ -22,6 +24,7 @@ export const AutoFixButton = ({ cellId: CellId; className?: string; }) => { + const { instantAIFix, setInstantAIFix } = useInstantAIFix(); const { createNewCell } = useCellActions(); const aiEnabled = useAtomValue(aiEnabledAtom); const autoFixes = errors.flatMap((error) => @@ -37,7 +40,7 @@ export const AutoFixButton = ({ // multiple fixes. const firstFix = autoFixes[0]; - const handleFix = (aiInstantFix = false) => { + const handleFix = () => { const editorView = store.get(notebookAtom).cellHandles[cellId].current?.editorView; firstFix.onFix({ @@ -53,7 +56,7 @@ export const AutoFixButton = ({ cellId: cellId, aiFix: { setAiCompletionCell, - instantFix: aiInstantFix, + instantFix: instantAIFix, }, }); // Focus the editor @@ -61,13 +64,13 @@ export const AutoFixButton = ({ }; return ( -
+
- +
+ setInstantAIFix(!instantAIFix)} + size="sm" + className="h-4 w-8" + title="Toggle instant AI fix mode" + /> + + {instantAIFix ? ( + + ) : ( + + )} + +
)}
); From 57c3414ac53dffe59b0fb829cb0cb4630f88a47d Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 3 Oct 2025 01:08:17 +0800 Subject: [PATCH 5/8] rename file --- frontend/src/components/editor/errors/auto-fix.tsx | 2 +- .../editor/errors/{auto-fix-atom.ts => instant-fix.ts} | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) rename frontend/src/components/editor/errors/{auto-fix-atom.ts => instant-fix.ts} (89%) diff --git a/frontend/src/components/editor/errors/auto-fix.tsx b/frontend/src/components/editor/errors/auto-fix.tsx index 8fa9c2bd052..c7a14167a69 100644 --- a/frontend/src/components/editor/errors/auto-fix.tsx +++ b/frontend/src/components/editor/errors/auto-fix.tsx @@ -13,7 +13,7 @@ import { getAutoFixes } from "@/core/errors/errors"; import type { MarimoError } from "@/core/kernel/messages"; import { store } from "@/core/state/jotai"; import { cn } from "@/utils/cn"; -import { useInstantAIFix } from "./auto-fix-atom"; +import { useInstantAIFix } from "./instant-fix"; export const AutoFixButton = ({ errors, diff --git a/frontend/src/components/editor/errors/auto-fix-atom.ts b/frontend/src/components/editor/errors/instant-fix.ts similarity index 89% rename from frontend/src/components/editor/errors/auto-fix-atom.ts rename to frontend/src/components/editor/errors/instant-fix.ts index 25fe2a12508..30fac9f9480 100644 --- a/frontend/src/components/editor/errors/auto-fix-atom.ts +++ b/frontend/src/components/editor/errors/instant-fix.ts @@ -4,6 +4,7 @@ import { useAtom } from "jotai"; import { atomWithStorage } from "jotai/utils"; import { store } from "@/core/state/jotai"; +// If true, AI will immediately fix errors where possible const BASE_KEY = "marimo:instant-ai-fix"; const instantAIFixAtom = atomWithStorage(BASE_KEY, false); From a79ffd07212c30e7f378081dec9215e12f80f9e6 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 3 Oct 2025 01:19:11 +0800 Subject: [PATCH 6/8] better prompt --- frontend/src/core/errors/errors.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/core/errors/errors.ts b/frontend/src/core/errors/errors.ts index e84e4e6b206..4d029bae200 100644 --- a/frontend/src/core/errors/errors.ts +++ b/frontend/src/core/errors/errors.ts @@ -87,7 +87,7 @@ export function getAutoFixes( const datasourceContext = getDatasourceContext(ctx.cellId); let initialPrompt = `Fix the SQL statement: ${error.msg}.`; if (datasourceContext) { - initialPrompt += `\nUse the following tables and schema to form your query: ${datasourceContext}`; + initialPrompt += `\nDatabase schema: ${datasourceContext}`; } ctx.aiFix?.setAiCompletionCell({ cellId: ctx.cellId, From 5b1dd1cef45c4afb7d117c8d765e8ea36c7780b8 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 3 Oct 2025 12:20:00 +0800 Subject: [PATCH 7/8] change to fix mode --- .../src/components/editor/errors/auto-fix.tsx | 145 +++++++++++++----- .../src/components/editor/errors/fix-mode.ts | 13 ++ .../components/editor/errors/instant-fix.ts | 19 --- frontend/src/core/errors/errors.ts | 4 +- 4 files changed, 124 insertions(+), 57 deletions(-) create mode 100644 frontend/src/components/editor/errors/fix-mode.ts delete mode 100644 frontend/src/components/editor/errors/instant-fix.ts diff --git a/frontend/src/components/editor/errors/auto-fix.tsx b/frontend/src/components/editor/errors/auto-fix.tsx index c7a14167a69..761311d1d43 100644 --- a/frontend/src/components/editor/errors/auto-fix.tsx +++ b/frontend/src/components/editor/errors/auto-fix.tsx @@ -1,9 +1,14 @@ /* Copyright 2024 Marimo. All rights reserved. */ -import { useAtomValue, useSetAtom } from "jotai"; -import { WrenchIcon, ZapIcon, ZapOffIcon } from "lucide-react"; +import { useAtomValue, useSetAtom, useStore } from "jotai"; +import { ChevronDownIcon, SparklesIcon, WrenchIcon } from "lucide-react"; import { Button } from "@/components/ui/button"; -import { Switch } from "@/components/ui/switch"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; import { Tooltip } from "@/components/ui/tooltip"; import { aiCompletionCellAtom } from "@/core/ai/state"; import { notebookAtom, useCellActions } from "@/core/cells/cells"; @@ -11,9 +16,8 @@ import type { CellId } from "@/core/cells/ids"; import { aiEnabledAtom } from "@/core/config/config"; import { getAutoFixes } from "@/core/errors/errors"; import type { MarimoError } from "@/core/kernel/messages"; -import { store } from "@/core/state/jotai"; import { cn } from "@/utils/cn"; -import { useInstantAIFix } from "./instant-fix"; +import { useFixMode } from "./fix-mode"; export const AutoFixButton = ({ errors, @@ -24,7 +28,7 @@ export const AutoFixButton = ({ cellId: CellId; className?: string; }) => { - const { instantAIFix, setInstantAIFix } = useInstantAIFix(); + const store = useStore(); const { createNewCell } = useCellActions(); const aiEnabled = useAtomValue(aiEnabledAtom); const autoFixes = errors.flatMap((error) => @@ -40,7 +44,7 @@ export const AutoFixButton = ({ // multiple fixes. const firstFix = autoFixes[0]; - const handleFix = () => { + const handleFix = (triggerFix: boolean) => { const editorView = store.get(notebookAtom).cellHandles[cellId].current?.editorView; firstFix.onFix({ @@ -56,7 +60,7 @@ export const AutoFixButton = ({ cellId: cellId, aiFix: { setAiCompletionCell, - instantFix: instantAIFix, + triggerFix, }, }); // Focus the editor @@ -64,41 +68,110 @@ export const AutoFixButton = ({ }; return ( -
- +
+ {firstFix.fixType === "ai" ? ( + handleFix(false)} + applyAutofix={() => handleFix(true)} + /> + ) : ( + + + + )} +
+ ); +}; + +const PromptIcon = SparklesIcon; +const AutofixIcon = WrenchIcon; + +const PromptTitle = "Suggest a prompt"; +const AutofixTitle = "Fix with AI"; + +export const AIFixButton = ({ + tooltip, + openPrompt, + applyAutofix, +}: { + tooltip: string; + openPrompt: () => void; + applyAutofix: () => void; +}) => { + const { fixMode, setFixMode } = useFixMode(); + + return ( +
+ - - {firstFix.fixType === "ai" && ( -
- setInstantAIFix(!instantAIFix)} - size="sm" - className="h-4 w-8" - title="Toggle instant AI fix mode" - /> - + +
- )} + + + + + { + setFixMode(fixMode === "prompt" ? "autofix" : "prompt"); + }} + > + + + + +
+ ); +}; + +const AiModeItem = ({ mode }: { mode: "prompt" | "autofix" }) => { + const icon = + mode === "prompt" ? ( + + ) : ( + + ); + const title = mode === "prompt" ? PromptTitle : AutofixTitle; + const description = + mode === "prompt" + ? "Edit the prompt before applying" + : "Apply AI fixes automatically"; + + return ( +
+ {icon} +
+ {title} + {description} +
); }; diff --git a/frontend/src/components/editor/errors/fix-mode.ts b/frontend/src/components/editor/errors/fix-mode.ts new file mode 100644 index 00000000000..6a03948d080 --- /dev/null +++ b/frontend/src/components/editor/errors/fix-mode.ts @@ -0,0 +1,13 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { useAtom } from "jotai"; +import { atomWithStorage } from "jotai/utils"; + +const BASE_KEY = "marimo:ai-autofix-mode"; + +const fixModeAtom = atomWithStorage<"prompt" | "autofix">(BASE_KEY, "autofix"); + +export function useFixMode() { + const [fixMode, setFixMode] = useAtom(fixModeAtom); + return { fixMode, setFixMode }; +} diff --git a/frontend/src/components/editor/errors/instant-fix.ts b/frontend/src/components/editor/errors/instant-fix.ts deleted file mode 100644 index 30fac9f9480..00000000000 --- a/frontend/src/components/editor/errors/instant-fix.ts +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2024 Marimo. All rights reserved. */ - -import { useAtom } from "jotai"; -import { atomWithStorage } from "jotai/utils"; -import { store } from "@/core/state/jotai"; - -// If true, AI will immediately fix errors where possible -const BASE_KEY = "marimo:instant-ai-fix"; - -const instantAIFixAtom = atomWithStorage(BASE_KEY, false); - -export function useInstantAIFix() { - const [instantAIFix, setInstantAIFix] = useAtom(instantAIFixAtom); - return { instantAIFix, setInstantAIFix }; -} - -export function getInstantAIFix() { - return store.get(instantAIFixAtom); -} diff --git a/frontend/src/core/errors/errors.ts b/frontend/src/core/errors/errors.ts index 4d029bae200..4e7a67df404 100644 --- a/frontend/src/core/errors/errors.ts +++ b/frontend/src/core/errors/errors.ts @@ -9,7 +9,7 @@ import { wrapInFunction } from "./utils"; interface AIFix { setAiCompletionCell: (opts: AiCompletionCell) => void; - instantFix: boolean; + triggerFix: boolean; } export interface AutoFix { @@ -92,7 +92,7 @@ export function getAutoFixes( ctx.aiFix?.setAiCompletionCell({ cellId: ctx.cellId, initialPrompt: initialPrompt, - triggerImmediately: ctx.aiFix.instantFix, + triggerImmediately: ctx.aiFix.triggerFix, }); }, }, From d350f75282bb6e461480b5da3cb0e991f30e36f5 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 3 Oct 2025 12:29:58 +0800 Subject: [PATCH 8/8] type update --- frontend/src/components/editor/errors/auto-fix.tsx | 4 ++-- frontend/src/components/editor/errors/fix-mode.ts | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/frontend/src/components/editor/errors/auto-fix.tsx b/frontend/src/components/editor/errors/auto-fix.tsx index 761311d1d43..d1c65148994 100644 --- a/frontend/src/components/editor/errors/auto-fix.tsx +++ b/frontend/src/components/editor/errors/auto-fix.tsx @@ -17,7 +17,7 @@ import { aiEnabledAtom } from "@/core/config/config"; import { getAutoFixes } from "@/core/errors/errors"; import type { MarimoError } from "@/core/kernel/messages"; import { cn } from "@/utils/cn"; -import { useFixMode } from "./fix-mode"; +import { type FixMode, useFixMode } from "./fix-mode"; export const AutoFixButton = ({ errors, @@ -152,7 +152,7 @@ export const AIFixButton = ({ ); }; -const AiModeItem = ({ mode }: { mode: "prompt" | "autofix" }) => { +const AiModeItem = ({ mode }: { mode: FixMode }) => { const icon = mode === "prompt" ? ( diff --git a/frontend/src/components/editor/errors/fix-mode.ts b/frontend/src/components/editor/errors/fix-mode.ts index 6a03948d080..8adf10d538b 100644 --- a/frontend/src/components/editor/errors/fix-mode.ts +++ b/frontend/src/components/editor/errors/fix-mode.ts @@ -3,9 +3,11 @@ import { useAtom } from "jotai"; import { atomWithStorage } from "jotai/utils"; +export type FixMode = "prompt" | "autofix"; + const BASE_KEY = "marimo:ai-autofix-mode"; -const fixModeAtom = atomWithStorage<"prompt" | "autofix">(BASE_KEY, "autofix"); +const fixModeAtom = atomWithStorage(BASE_KEY, "autofix"); export function useFixMode() { const [fixMode, setFixMode] = useAtom(fixModeAtom);