Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions frontend/src/components/editor/errors/auto-fix.tsx
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
/* Copyright 2024 Marimo. All rights reserved. */

import { useAtomValue, useSetAtom } from "jotai";
import { WrenchIcon } from "lucide-react";
import { Button } from "@/components/ui/button";
import { Tooltip } from "@/components/ui/tooltip";
import { aiCompletionCellAtom } from "@/core/ai/state";
import { notebookAtom, useCellActions } from "@/core/cells/cells";
import type { CellId } from "@/core/cells/ids";
import { aiEnabledAtom } from "@/core/config/config";
import { getAutoFixes } from "@/core/errors/errors";
import type { MarimoError } from "@/core/kernel/messages";
import { store } from "@/core/state/jotai";
import { cn } from "@/utils/cn";

export const AutoFixButton = ({
errors,
cellId,
className,
}: {
errors: MarimoError[];
cellId: CellId;
className?: string;
}) => {
const { createNewCell } = useCellActions();
const autoFixes = errors.flatMap((error) => getAutoFixes(error));
const aiEnabled = useAtomValue(aiEnabledAtom);
const autoFixes = errors.flatMap((error) =>
getAutoFixes(error, { aiEnabled }),
);
const setAiCompletionCell = useSetAtom(aiCompletionCellAtom);

if (autoFixes.length === 0) {
return null;
Expand All @@ -33,7 +42,7 @@ export const AutoFixButton = ({
<Button
size="xs"
variant="outline"
className="my-2 font-normal"
className={cn("my-2 font-normal", className)}
onClick={() => {
const editorView =
store.get(notebookAtom).cellHandles[cellId].current?.editorView;
Expand All @@ -48,6 +57,7 @@ export const AutoFixButton = ({
},
editor: editorView,
cellId: cellId,
setAiCompletionCell,
});
// Focus the editor
editorView?.focus();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,13 @@ export const MarimoErrorOutput = ({
</div>
);
})}
{cellId && <AutoFixButton errors={sqlErrors} cellId={cellId} />}
{cellId && (
<AutoFixButton
errors={sqlErrors}
cellId={cellId}
className="mt-2.5"
/>
)}
</div>,
);
}
Expand Down
26 changes: 22 additions & 4 deletions frontend/src/core/errors/__tests__/errors.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ describe("getImportCode", () => {
});
});

const opts = {
aiEnabled: true,
};

describe("getAutoFixes", () => {
it("returns wrap in function fix for multiple-defs error", () => {
const error: MarimoError = {
Expand All @@ -24,7 +28,7 @@ describe("getAutoFixes", () => {
cells: ["foo"],
};

const fixes = getAutoFixes(error);
const fixes = getAutoFixes(error, opts);
expect(fixes).toHaveLength(1);
expect(fixes[0].title).toBe("Fix: Wrap in a function");
});
Expand All @@ -37,11 +41,25 @@ describe("getAutoFixes", () => {
raising_cell: null,
};

const fixes = getAutoFixes(error);
const fixes = getAutoFixes(error, opts);
expect(fixes).toHaveLength(1);
expect(fixes[0].title).toBe("Fix: Add 'import numpy as np'");
});

it("returns sql fix for sql-error error", () => {
const error: MarimoError = {
type: "sql-error",
msg: "syntax error",
sql_statement: "SELECT * FROM table",
};

const fixes = getAutoFixes(error, opts);
expect(fixes).toHaveLength(1);
expect(fixes[0].title).toBe("Fix with AI");

expect(getAutoFixes(error, { aiEnabled: false })).toHaveLength(0);
});

it("returns no fixes for NameError with unknown import", () => {
const error: MarimoError = {
type: "exception",
Expand All @@ -50,7 +68,7 @@ describe("getAutoFixes", () => {
raising_cell: null,
};

expect(getAutoFixes(error)).toHaveLength(0);
expect(getAutoFixes(error, opts)).toHaveLength(0);
});

it("returns no fixes for other error types", () => {
Expand All @@ -59,6 +77,6 @@ describe("getAutoFixes", () => {
msg: "invalid syntax",
};

expect(getAutoFixes(error)).toHaveLength(0);
expect(getAutoFixes(error, opts)).toHaveLength(0);
});
});
30 changes: 29 additions & 1 deletion frontend/src/core/errors/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,19 @@ export interface AutoFix {
addCodeBelow: (code: string) => void;
editor: EditorView | undefined;
cellId: CellId;
setAiCompletionCell?: (cell: {
cellId: CellId;
initialPrompt?: string;
}) => void;
}) => Promise<void>;
}

export function getAutoFixes(error: MarimoError): AutoFix[] {
export function getAutoFixes(
error: MarimoError,
opts: {
aiEnabled: boolean;
},
): AutoFix[] {
if (error.type === "multiple-defs") {
return [
{
Expand Down Expand Up @@ -57,6 +66,25 @@ export function getAutoFixes(error: MarimoError): AutoFix[] {
];
}

if (error.type === "sql-error") {
// Only show AI fix if AI is enabled
if (!opts.aiEnabled) {
return [];
}
return [
{
title: "Fix with AI",
description: "Fix the SQL statement",
onFix: async (ctx) => {
ctx.setAiCompletionCell?.({
cellId: ctx.cellId,
initialPrompt: `Fix the SQL statement: ${error.msg}`,
});
},
},
];
}

return [];
}

Expand Down
4 changes: 1 addition & 3 deletions marimo/_server/ai/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
"markdown": [],
"sql": [
"The SQL must use duckdb syntax.",
'SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.',
"This will automatically display the result in the UI. You do not need to return the dataframe in the cell.",
],
}

Expand Down Expand Up @@ -194,7 +192,7 @@ def get_refactor_or_insert_notebook_cell_system_prompt(
if support_multiple_cells:
system_prompt += "\n\nAgain, just output code wrapped in cells. Each cell is wrapped in backticks with the appropriate language identifier (python, sql, markdown)."
else:
system_prompt += "\n\nAgain, just output the code itself."
system_prompt += f"\n\nAgain, just output the code itself and make sure to return the code as just {language}."

return system_prompt

Expand Down
14 changes: 0 additions & 14 deletions tests/_server/ai/snapshots/chat_system_prompts.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ chart

## Rules for sql:
1. The SQL must use duckdb syntax.
2. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
3. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.

==================== with custom rules ====================

Expand Down Expand Up @@ -285,8 +283,6 @@ chart

## Rules for sql:
1. The SQL must use duckdb syntax.
2. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
3. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.

## Additional rules:
Always be polite.
Expand Down Expand Up @@ -432,8 +428,6 @@ chart

## Rules for sql:
1. The SQL must use duckdb syntax.
2. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
3. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.

## Available variables from other cells:
- variable: `var1`- variable: `var2`
Expand Down Expand Up @@ -579,8 +573,6 @@ chart

## Rules for sql:
1. The SQL must use duckdb syntax.
2. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
3. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.

## Available variables from other cells:
- variable: `df`
Expand Down Expand Up @@ -732,8 +724,6 @@ chart

## Rules for sql:
1. The SQL must use duckdb syntax.
2. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
3. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.

## Available schema:
- Table: df_1
Expand Down Expand Up @@ -893,8 +883,6 @@ chart

## Rules for sql:
1. The SQL must use duckdb syntax.
2. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
3. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.

<code_from_other_cells>
import pandas as pd
Expand Down Expand Up @@ -1042,8 +1030,6 @@ chart

## Rules for sql:
1. The SQL must use duckdb syntax.
2. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
3. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.

## Additional rules:
Always be polite.
Expand Down
26 changes: 11 additions & 15 deletions tests/_server/ai/snapshots/system_prompts.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Immediately start with the following format. Do NOT comment on the code, just ou
6. If an import already exists, do not import it again.
7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning.

Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just python.

==================== markdown ====================

Expand All @@ -37,7 +37,7 @@ Immediately start with the following format. Do NOT comment on the code, just ou
{CELL_CODE}
```

Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just markdown.

==================== sql ====================

Expand All @@ -54,10 +54,8 @@ Immediately start with the following format. Do NOT comment on the code, just ou

## Rules for sql
1. The SQL must use duckdb syntax.
2. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
3. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.

Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just sql.

==================== idk ====================

Expand All @@ -72,7 +70,7 @@ Immediately start with the following format. Do NOT comment on the code, just ou
{CELL_CODE}
```

Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just idk.

==================== with custom rules ====================

Expand All @@ -99,7 +97,7 @@ Immediately start with the following format. Do NOT comment on the code, just ou
## Additional rules:
Always use type hints.

Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just python.

==================== with context ====================

Expand Down Expand Up @@ -133,7 +131,7 @@ Immediately start with the following format. Do NOT comment on the code, just ou
- Sample values: Alice, Bob, Charlie


Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just python.

==================== with is_insert=True ====================

Expand Down Expand Up @@ -164,7 +162,7 @@ Immediately start with the following format. Do NOT comment on the code, just ou
6. If an import already exists, do not import it again.
7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning.

Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just python.

==================== with cell_code ====================

Expand Down Expand Up @@ -201,7 +199,7 @@ print('Hello, world!')
6. If an import already exists, do not import it again.
7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning.

Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just python.

==================== with selected_text ====================

Expand Down Expand Up @@ -238,7 +236,7 @@ print('Hello, world!')
6. If an import already exists, do not import it again.
7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning.

Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just python.

==================== with other_cell_codes ====================

Expand Down Expand Up @@ -279,7 +277,7 @@ import pandas as pd
import numpy as np
</code_from_other_cells>

Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just python.

==================== with VariableContext objects ====================

Expand Down Expand Up @@ -312,7 +310,7 @@ Immediately start with the following format. Do NOT comment on the code, just ou
- value_preview: <Model object>


Again, just output the code itself.
Again, just output the code itself and make sure to return the code as just python.

==================== with support_multiple_cells=True ====================

Expand Down Expand Up @@ -350,8 +348,6 @@ Separate logic into multiple cells to keep the code organized and readable.

## Rules for sql:
1. The SQL must use duckdb syntax.
2. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
3. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.

## Available variables from other cells:
- variable: `df`
Expand Down
Loading