Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions frontend/src/components/editor/ai/ai-completion-editor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { getCodes } from "@/core/codemirror/copilot/getCodes";
import type { LanguageAdapterType } from "@/core/codemirror/language/types";
import { selectAllText } from "@/core/codemirror/utils";
import { useRuntimeManager } from "@/core/runtime/config";
import { useEvent } from "@/hooks/useEvent";
import { useTheme } from "@/theme/useTheme";
import { cn } from "@/utils/cn";
import { prettyError } from "@/utils/errors";
Expand All @@ -45,6 +46,7 @@ interface Props {
declineChange: () => void;
acceptChange: (rightHandCode: string) => void;
enabled: boolean;
triggerImmediately?: boolean;
/**
* Children shown when there is no completion
*/
Expand All @@ -67,6 +69,7 @@ export const AiCompletionEditor: React.FC<Props> = ({
declineChange,
acceptChange,
enabled,
triggerImmediately,
children,
}) => {
const [completionBody, setCompletionBody] = useState<object>({});
Expand Down Expand Up @@ -94,7 +97,11 @@ export const AiCompletionEditor: React.FC<Props> = ({
// Throttle the messages and data updates to 100ms
experimental_throttle: 100,
body: {
...completionBody,
...(Object.keys(completionBody).length > 0
? completionBody
: initialPrompt
? getAICompletionBody({ input: initialPrompt })
: {}),
includeOtherCode: includeOtherCells ? getCodes(currentCode) : "",
code: currentCode,
language: currentLanguageAdapter,
Expand All @@ -114,6 +121,12 @@ export const AiCompletionEditor: React.FC<Props> = ({
const inputRef = React.useRef<ReactCodeMirrorRef>(null);
const completion = untrimmedCompletion.trimEnd();

const initialSubmit = useEvent(() => {
if (triggerImmediately && !isLoading && initialPrompt) {
handleSubmit();
}
});

// Focus the input
useEffect(() => {
if (enabled) {
Expand All @@ -122,6 +135,7 @@ export const AiCompletionEditor: React.FC<Props> = ({
const input = inputRef.current;
if (input?.view) {
input.view.focus();
initialSubmit();
return true;
}
return false;
Expand All @@ -131,7 +145,7 @@ export const AiCompletionEditor: React.FC<Props> = ({

selectAllText(inputRef.current?.view);
}
}, [enabled]);
}, [enabled, initialSubmit]);

// Reset the input when the prompt changes
useEffect(() => {
Expand Down
1 change: 1 addition & 0 deletions frontend/src/components/editor/cell/code/cell-editor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
);

return extensions;
}, [

Check warning on line 231 in frontend/src/components/editor/cell/code/cell-editor.tsx

View workflow job for this annotation

GitHub Actions / 🧹 Lint frontend

React Hook useMemo has a missing dependency: 'userConfig.ai?.inline_tooltip'. Either include it or remove the dependency array
cellId,
userConfig.keymap,
userConfig.completion,
Expand Down Expand Up @@ -381,7 +381,7 @@
]);

// Destroy the editor when the component is unmounted
useEffect(() => {

Check warning on line 384 in frontend/src/components/editor/cell/code/cell-editor.tsx

View workflow job for this annotation

GitHub Actions / 🧹 Lint frontend

This effect only uses props. Consider lifting the logic up to the parent
const ev = editorViewRef.current;
return () => {
ev?.destroy();
Expand All @@ -407,6 +407,7 @@
<AiCompletionEditor
enabled={aiCompletionCell?.cellId === cellId}
initialPrompt={aiCompletionCell?.initialPrompt}
triggerImmediately={aiCompletionCell?.triggerImmediately}
currentCode={editorViewRef.current?.state.doc.toString() ?? code}
currentLanguageAdapter={languageAdapter}
declineChange={useEvent(() => {
Expand Down
170 changes: 138 additions & 32 deletions frontend/src/components/editor/errors/auto-fix.tsx
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
/* Copyright 2024 Marimo. All rights reserved. */

import { useAtomValue, useSetAtom } from "jotai";
import { WrenchIcon } from "lucide-react";
import { useAtomValue, useSetAtom, useStore } from "jotai";
import { ChevronDownIcon, SparklesIcon, WrenchIcon } from "lucide-react";
import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Tooltip } from "@/components/ui/tooltip";
import { aiCompletionCellAtom } from "@/core/ai/state";
import { notebookAtom, useCellActions } from "@/core/cells/cells";
import type { CellId } from "@/core/cells/ids";
import { aiEnabledAtom } from "@/core/config/config";
import { getAutoFixes } from "@/core/errors/errors";
import type { MarimoError } from "@/core/kernel/messages";
import { store } from "@/core/state/jotai";
import { cn } from "@/utils/cn";
import { type FixMode, useFixMode } from "./fix-mode";

export const AutoFixButton = ({
errors,
Expand All @@ -22,6 +28,7 @@ export const AutoFixButton = ({
cellId: CellId;
className?: string;
}) => {
const store = useStore();
const { createNewCell } = useCellActions();
const aiEnabled = useAtomValue(aiEnabledAtom);
const autoFixes = errors.flatMap((error) =>
Expand All @@ -37,35 +44,134 @@ export const AutoFixButton = ({
// multiple fixes.
const firstFix = autoFixes[0];

const handleFix = (triggerFix: boolean) => {
const editorView =
store.get(notebookAtom).cellHandles[cellId].current?.editorView;
firstFix.onFix({
addCodeBelow: (code: string) => {
createNewCell({
cellId: cellId,
autoFocus: false,
before: false,
code: code,
});
},
editor: editorView,
cellId: cellId,
aiFix: {
setAiCompletionCell,
triggerFix,
},
});
// Focus the editor
editorView?.focus();
};

return (
<div className={cn("my-2", className)}>
{firstFix.fixType === "ai" ? (
<AIFixButton
tooltip={firstFix.description}
openPrompt={() => handleFix(false)}
applyAutofix={() => handleFix(true)}
/>
) : (
<Tooltip content={firstFix.description} align="start">
<Button
size="xs"
variant="outline"
className="font-normal"
onClick={() => handleFix(false)}
>
<WrenchIcon className="h-3 w-3 mr-2" />
{firstFix.title}
</Button>
</Tooltip>
)}
</div>
);
};

const PromptIcon = SparklesIcon;
const AutofixIcon = WrenchIcon;

const PromptTitle = "Suggest a prompt";
const AutofixTitle = "Fix with AI";

export const AIFixButton = ({
tooltip,
openPrompt,
applyAutofix,
}: {
tooltip: string;
openPrompt: () => void;
applyAutofix: () => void;
}) => {
const { fixMode, setFixMode } = useFixMode();

return (
<div className="flex">
<Tooltip content={tooltip} align="start">
<Button
size="xs"
variant="outline"
className="font-normal rounded-r-none border-r-0"
onClick={fixMode === "prompt" ? openPrompt : applyAutofix}
>
{fixMode === "prompt" ? (
<PromptIcon className="h-3 w-3 mr-2 mb-0.5" />
) : (
<AutofixIcon className="h-3 w-3 mr-2 mb-0.5" />
)}
{fixMode === "prompt" ? PromptTitle : AutofixTitle}
</Button>
</Tooltip>
<DropdownMenu>
<DropdownMenuTrigger asChild={true}>
<Button
size="xs"
variant="outline"
className="rounded-l-none px-2"
aria-label="Fix options"
>
<ChevronDownIcon className="h-3 w-3" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end" className="w-56">
<DropdownMenuItem
className="flex items-center gap-2"
onClick={() => {
setFixMode(fixMode === "prompt" ? "autofix" : "prompt");
}}
>
<AiModeItem mode={fixMode === "prompt" ? "autofix" : "prompt"} />
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
);
};

const AiModeItem = ({ mode }: { mode: FixMode }) => {
const icon =
mode === "prompt" ? (
<PromptIcon className="h-4 w-4" />
) : (
<AutofixIcon className="h-4 w-4" />
);
const title = mode === "prompt" ? PromptTitle : AutofixTitle;
const description =
mode === "prompt"
? "Edit the prompt before applying"
: "Apply AI fixes automatically";

return (
<Tooltip content={firstFix.description} align="start">
<Button
size="xs"
variant="outline"
className={cn("my-2 font-normal", className)}
onClick={() => {
const editorView =
store.get(notebookAtom).cellHandles[cellId].current?.editorView;
firstFix.onFix({
addCodeBelow: (code: string) => {
createNewCell({
cellId: cellId,
autoFocus: false,
before: false,
code: code,
});
},
editor: editorView,
cellId: cellId,
setAiCompletionCell,
});
// Focus the editor
editorView?.focus();
}}
>
<WrenchIcon className="h-3 w-3 mr-2" />
{firstFix.title}
</Button>
</Tooltip>
<div className="flex items-center gap-2">
{icon}
<div className="flex flex-col">
<span className="font-medium">{title}</span>
<span className="text-xs text-muted-foreground">{description}</span>
</div>
</div>
);
};
15 changes: 15 additions & 0 deletions frontend/src/components/editor/errors/fix-mode.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/* Copyright 2024 Marimo. All rights reserved. */

import { useAtom } from "jotai";
import { atomWithStorage } from "jotai/utils";

export type FixMode = "prompt" | "autofix";

const BASE_KEY = "marimo:ai-autofix-mode";

const fixModeAtom = atomWithStorage<FixMode>(BASE_KEY, "autofix");

export function useFixMode() {
const [fixMode, setFixMode] = useAtom(fixModeAtom);
return { fixMode, setFixMode };
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import { beforeEach, describe, expect, it } from "vitest";
import type { CellId } from "@/core/cells/ids";
import type {
ConnectionsMap,
DatasetTablesMap,
} from "@/core/datasets/data-source-connections";
import { DUCKDB_ENGINE } from "@/core/datasets/engines";
import type { DataSourceConnection, DataTable } from "@/core/kernel/messages";
import { Boosts, Sections } from "../common";
import { DatasourceContextProvider } from "../datasource";
import { DatasourceContextProvider, getDatasourceContext } from "../datasource";

// Mock data for testing
const createMockDataSourceConnection = (
Expand Down Expand Up @@ -630,3 +631,10 @@ describe("DatasourceContextProvider", () => {
});
});
});

describe("getDatasourceContext", () => {
it("should return null if no cell ID is found", () => {
const context = getDatasourceContext("1" as CellId);
expect(context).toBeNull();
});
});
26 changes: 25 additions & 1 deletion frontend/src/core/ai/context/providers/datasource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
import type { Completion } from "@codemirror/autocomplete";
import { createRoot } from "react-dom/client";
import { dbDisplayName } from "@/components/databases/display";
import { cellDataAtom } from "@/core/cells/cells";
import type { CellId } from "@/core/cells/ids";
import { LanguageAdapters } from "@/core/codemirror/language/LanguageAdapters";
import { renderDatasourceInfo } from "@/core/codemirror/language/languages/sql/renderers";
import {
type ConnectionsMap,
type DatasetTablesMap,
dataSourceConnectionsAtom,
getTableType,
} from "@/core/datasets/data-source-connections";
import {
type ConnectionName,
INTERNAL_SQL_ENGINES,
} from "@/core/datasets/engines";
import type { DataSourceConnection, DataTable } from "@/core/kernel/messages";
import { store } from "@/core/state/jotai";
import type { AIContextItem } from "../registry";
import { AIContextProvider } from "../registry";
import { contextToXml } from "../utils";
Expand All @@ -37,10 +42,12 @@ export interface DatasourceContextItem extends AIContextItem {
};
}

const CONTEXT_TYPE = "datasource";

export class DatasourceContextProvider extends AIContextProvider<DatasourceContextItem> {
readonly title = "Datasource";
readonly mentionPrefix = "@";
readonly contextType = "datasource";
readonly contextType = CONTEXT_TYPE;
private connectionsMap: ConnectionsMap;
private dataframes: DataTable[];

Expand Down Expand Up @@ -140,3 +147,20 @@ export class DatasourceContextProvider extends AIContextProvider<DatasourceConte
};
}
}

export function getDatasourceContext(cellId: CellId): string | null {
const cellData = store.get(cellDataAtom(cellId));
const code = cellData?.code;
if (!code || code.trim() === "") {
return null;
}

const [_sqlStatement, _, metadata] = LanguageAdapters.sql.transformIn(code);
const datasourceSchema = store
.get(dataSourceConnectionsAtom)
.connectionsMap.get(metadata.engine);
if (datasourceSchema) {
return `@${CONTEXT_TYPE}://${datasourceSchema.name}`;
}
return null;
}
Loading
Loading