diff --git a/frontend/src/__mocks__/requests.ts b/frontend/src/__mocks__/requests.ts index aa9a50a2897..aea49e35329 100644 --- a/frontend/src/__mocks__/requests.ts +++ b/frontend/src/__mocks__/requests.ts @@ -40,6 +40,7 @@ export const MockRequestClient = { previewSQLTable: vi.fn().mockResolvedValue({}), previewSQLTableList: vi.fn().mockResolvedValue({ tables: [] }), previewDataSourceConnection: vi.fn().mockResolvedValue({}), + validateSQL: vi.fn().mockResolvedValue({}), openFile: vi.fn().mockResolvedValue({}), getUsageStats: vi.fn().mockResolvedValue({}), sendPdb: vi.fn().mockResolvedValue({}), diff --git a/frontend/src/components/editor/Cell.tsx b/frontend/src/components/editor/Cell.tsx index d0b3bbfcd5b..27b1ee579fd 100644 --- a/frontend/src/components/editor/Cell.tsx +++ b/frontend/src/components/editor/Cell.tsx @@ -80,6 +80,7 @@ import { useDeleteCellCallback } from "./cell/useDeleteCell"; import { useRunCell } from "./cell/useRunCells"; import { HideCodeButton } from "./code/readonly-python-code"; import { cellDomProps } from "./common"; +import { SqlValidationErrorBanner } from "./errors/sql-validation-errors"; import { useCellNavigationProps } from "./navigation/navigation"; import { useTemporarilyShownCode, @@ -653,6 +654,7 @@ const EditableCellComponent = ({ )} + {cellOutput === "below" && outputArea} {cellRuntime.serialization && (
diff --git a/frontend/src/components/editor/errors/sql-validation-errors.tsx b/frontend/src/components/editor/errors/sql-validation-errors.tsx new file mode 100644 index 00000000000..835be893215 --- /dev/null +++ b/frontend/src/components/editor/errors/sql-validation-errors.tsx @@ -0,0 +1,34 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { AlertCircleIcon } from "lucide-react"; +import type { CellId } from "@/core/cells/ids"; +import { useSqlValidationErrorsForCell } from "@/core/codemirror/language/languages/sql/validation-errors"; + +export const SqlValidationErrorBanner = ({ cellId }: { cellId: CellId }) => { + const error = useSqlValidationErrorsForCell(cellId); + + if (!error) { + return; + } + + return ( +
+
+ +

+ {error.errorType}:{" "} + {error.errorMessage} +

+
+ + {error.codeblock && ( +
+          {error.codeblock}
+        
+ )} +
+ ); +}; diff --git a/frontend/src/core/codemirror/language/__tests__/extension.test.ts b/frontend/src/core/codemirror/language/__tests__/extension.test.ts index 4f0188916f1..e6d04c16729 100644 --- a/frontend/src/core/codemirror/language/__tests__/extension.test.ts +++ b/frontend/src/core/codemirror/language/__tests__/extension.test.ts @@ -14,6 +14,7 @@ import { languageAdapterState, switchLanguage, } from "../extension"; +import { exportedForTesting as sqlValidationErrorsForTesting } from "../languages/sql/validation-errors"; import { languageMetadataField } from "../metadata"; let view: EditorView | null = null; @@ -258,3 +259,26 @@ describe("switchLanguage", () => { }); }); }); + +describe("sqlValidationErrors", () => { + const { splitErrorMessage } = sqlValidationErrorsForTesting; + + describe("split error message", () => { + it("should split the error message into error type and error message", () => { + const error = "SyntaxError: SELECT * FROM df"; + const { errorType, errorMessage } = splitErrorMessage(error); + expect(errorType).toBe("SyntaxError"); + expect(errorMessage).toBe("SELECT * FROM df"); + }); + + it("should handle multiple colons", () => { + const error = + "SyntaxError: SELECT * FROM df:SyntaxError: SELECT * FROM df"; + const { errorType, errorMessage } = splitErrorMessage(error); + expect(errorType).toBe("SyntaxError"); + expect(errorMessage).toBe( + "SELECT * FROM df:SyntaxError: SELECT * FROM df", + ); + }); + }); +}); diff --git a/frontend/src/core/codemirror/language/__tests__/sql-validation.test.ts b/frontend/src/core/codemirror/language/__tests__/sql-validation.test.ts new file mode 100644 index 00000000000..9edd9c1e738 --- /dev/null +++ b/frontend/src/core/codemirror/language/__tests__/sql-validation.test.ts @@ -0,0 +1,133 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { describe, expect, it } from "vitest"; +import { exportedForTesting } from "../languages/sql/validation-errors"; + +describe("Error Message Splitting", () => { + it("should handle error message splitting correctly", () => { + const { splitErrorMessage } = exportedForTesting; + + const result1 = splitErrorMessage("Syntax error: unexpected token"); + expect(result1.errorType).toBe("Syntax error"); + expect(result1.errorMessage).toBe("unexpected token"); + + const result2 = splitErrorMessage("Multiple: colons: in error"); + expect(result2.errorType).toBe("Multiple"); + expect(result2.errorMessage).toBe("colons: in error"); + + const result3 = splitErrorMessage("No colon error"); + expect(result3.errorType).toBe("No colon error"); + expect(result3.errorMessage).toBe(""); + }); +}); + +describe("DuckDB Error Handling", () => { + it("should extract codeblock from error with LINE information", () => { + const { handleDuckdbError } = exportedForTesting; + + const error = + 'Binder Error: Referenced column "attacks" not found in FROM clause! Candidate bindings: "Attack", "Total" LINE 1:... from pokemon WHERE \'type_2\' = 32 and attack = 32 and not attacks = \'hi\' ^'; + + const result = handleDuckdbError(error); + + expect(result.errorType).toBe("Binder Error"); + expect(result.errorMessage).toBe( + 'Referenced column "attacks" not found in FROM clause! Candidate bindings: "Attack", "Total"', + ); + expect(result.codeblock).toBe( + "LINE 1:... from pokemon WHERE 'type_2' = 32 and attack = 32 and not attacks = 'hi' ^", + ); + }); + + it("should handle error without LINE information", () => { + const { handleDuckdbError } = exportedForTesting; + + const error = "Syntax Error: Invalid syntax near WHERE"; + + const result = handleDuckdbError(error); + + expect(result.errorType).toBe("Syntax Error"); + expect(result.errorMessage).toBe("Invalid syntax near WHERE"); + expect(result.codeblock).toBeUndefined(); + }); + + it("should handle error with LINE at the beginning", () => { + const { handleDuckdbError } = exportedForTesting; + + const error = "LINE 1: SELECT * FROM table WHERE invalid_column = 1 ^"; + + const result = handleDuckdbError(error); + + expect(result.errorType).toBe("LINE 1"); + expect(result.errorMessage).toBe( + "SELECT * FROM table WHERE invalid_column = 1 ^", + ); + expect(result.codeblock).toBeUndefined(); + }); + + it("should handle error with multiple LINE occurrences", () => { + const { handleDuckdbError } = exportedForTesting; + + const error = + "Error: Something went wrong LINE 1: SELECT * FROM table WHERE invalid_column = 1 ^"; + + const result = handleDuckdbError(error); + + expect(result.errorType).toBe("Error"); + expect(result.errorMessage).toBe("Something went wrong"); + expect(result.codeblock).toBe( + "LINE 1: SELECT * FROM table WHERE invalid_column = 1 ^", + ); + }); + + it("should handle complex error with nested quotes", () => { + const { handleDuckdbError } = exportedForTesting; + + const error = + "Binder Error: Column \"name\" not found! LINE 1: SELECT * FROM users WHERE name = 'John' AND age > 25 ^"; + + const result = handleDuckdbError(error); + + expect(result.errorType).toBe("Binder Error"); + expect(result.errorMessage).toBe('Column "name" not found!'); + expect(result.codeblock).toBe( + "LINE 1: SELECT * FROM users WHERE name = 'John' AND age > 25 ^", + ); + }); + + it("should handle error with LINE but no caret", () => { + const { handleDuckdbError } = exportedForTesting; + + const error = "Error: Invalid query LINE 1: SELECT * FROM table"; + + const result = handleDuckdbError(error); + + expect(result.errorType).toBe("Error"); + expect(result.errorMessage).toBe("Invalid query"); + expect(result.codeblock).toBe("LINE 1: SELECT * FROM table"); + }); + + it("should trim whitespace from codeblock", () => { + const { handleDuckdbError } = exportedForTesting; + + const error = "Error: Something wrong LINE 1: SELECT * FROM table ^ "; + + const result = handleDuckdbError(error); + + expect(result.errorType).toBe("Error"); + expect(result.errorMessage).toBe("Something wrong"); + expect(result.codeblock).toBe("LINE 1: SELECT * FROM table ^"); + }); + + it("should handle empty error message", () => { + const { handleDuckdbError } = exportedForTesting; + + const error = ""; + + const result = handleDuckdbError(error); + + expect(result.errorType).toBe(""); + expect(result.errorMessage).toBe(""); + expect(result.codeblock).toBeUndefined(); + }); +}); diff --git a/frontend/src/core/codemirror/language/languages/sql/sql-mode.ts b/frontend/src/core/codemirror/language/languages/sql/sql-mode.ts new file mode 100644 index 00000000000..31587629393 --- /dev/null +++ b/frontend/src/core/codemirror/language/languages/sql/sql-mode.ts @@ -0,0 +1,20 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { useAtom } from "jotai"; +import { atomWithStorage } from "jotai/utils"; +import { store } from "@/core/state/jotai"; + +const BASE_KEY = "marimo:notebook-sql-mode"; + +export type SQLMode = "validate" | "default"; + +const sqlModeAtom = atomWithStorage(BASE_KEY, "default"); + +export function useSQLMode() { + const [sqlMode, setSQLMode] = useAtom(sqlModeAtom); + return { sqlMode, setSQLMode }; +} + +export function getSQLMode() { + return store.get(sqlModeAtom); +} diff --git a/frontend/src/core/codemirror/language/languages/sql/sql.ts b/frontend/src/core/codemirror/language/languages/sql/sql.ts index bc6c6e0fb82..3ef15ed407e 100644 --- a/frontend/src/core/codemirror/language/languages/sql/sql.ts +++ b/frontend/src/core/codemirror/language/languages/sql/sql.ts @@ -5,7 +5,7 @@ import { insertTab } from "@codemirror/commands"; import { type SQLDialect, type SQLNamespace, sql } from "@codemirror/lang-sql"; import type { EditorState, Extension } from "@codemirror/state"; import { Compartment } from "@codemirror/state"; -import { type EditorView, keymap } from "@codemirror/view"; +import { EditorView, keymap } from "@codemirror/view"; import type { SyntaxNode, TreeCursor } from "@lezer/common"; import { parser } from "@lezer/python"; import { @@ -16,12 +16,19 @@ import { } from "@marimo-team/codemirror-sql"; import { DuckDBDialect } from "@marimo-team/codemirror-sql/dialects"; import dedent from "string-dedent"; +import { cellIdState } from "@/core/codemirror/cells/state"; import { getFeatureFlag } from "@/core/config/feature-flag"; import { dataSourceConnectionsAtom, setLatestEngineSelected, } from "@/core/datasets/data-source-connections"; -import { type ConnectionName, DUCKDB_ENGINE } from "@/core/datasets/engines"; +import { + type ConnectionName, + DUCKDB_ENGINE, + INTERNAL_SQL_ENGINES, +} from "@/core/datasets/engines"; +import { ValidateSQL } from "@/core/datasets/request-registry"; +import type { ValidateSQLResult } from "@/core/kernel/messages"; import { store } from "@/core/state/jotai"; import { resolvedThemeAtom } from "@/theme/useTheme"; import { Logger } from "@/utils/Logger"; @@ -37,6 +44,11 @@ import { tablesCompletionSource, } from "./completion-sources"; import { SCHEMA_CACHE } from "./completion-store"; +import { getSQLMode } from "./sql-mode"; +import { + clearSqlValidationError, + setSqlValidationError, +} from "./validation-errors"; const DEFAULT_DIALECT = DuckDBDialect; const DEFAULT_PARSER_DIALECT = "DuckDB"; @@ -64,12 +76,15 @@ export class SQLLanguageAdapter { readonly type = "sql"; sqlLinterEnabled: boolean; + sqlModeEnabled: boolean; constructor() { try { this.sqlLinterEnabled = getFeatureFlag("sql_linter"); + this.sqlModeEnabled = getFeatureFlag("sql_mode"); } catch { this.sqlLinterEnabled = false; + this.sqlModeEnabled = false; } } @@ -265,6 +280,10 @@ export class SQLLanguageAdapter ); } + if (this.sqlModeEnabled) { + extensions.push(sqlValidationExtension()); + } + return extensions; } } @@ -315,9 +334,14 @@ function getSchema(view: EditorView): SQLNamespace { function guessParserDialect(state: EditorState): ParserDialects | null { const metadata = getSQLMetadata(state); const connectionName = metadata.engine; + return connectionNameToParserDialect(connectionName); +} + +function connectionNameToParserDialect( + connectionName: ConnectionName, +): ParserDialects | null { const dialect = SCHEMA_CACHE.getInternalDialect(connectionName)?.toLowerCase(); - switch (dialect) { case "postgresql": case "postgres": @@ -543,3 +567,66 @@ function safeDedent(code: string): string { return code; } } + +function sqlValidationExtension(): Extension { + let debounceTimeout: NodeJS.Timeout | null = null; + let lastValidationRequest: string | null = null; + + return EditorView.updateListener.of((update) => { + const sqlMode = getSQLMode(); + if (sqlMode !== "validate") { + return; + } + + const metadata = getSQLMetadata(update.state); + const connectionName = metadata.engine; + if (!INTERNAL_SQL_ENGINES.has(connectionName)) { + // Currently only internal engines are supported + return; + } + + if (!update.docChanged) { + return; + } + + const doc = update.state.doc; + const sqlContent = doc.toString(); + + // Clear existing timeout + if (debounceTimeout) { + clearTimeout(debounceTimeout); + } + + // Debounce the validation call + debounceTimeout = setTimeout(async () => { + // Skip if content hasn't changed since last validation + if (lastValidationRequest === sqlContent) { + return; + } + + lastValidationRequest = sqlContent; + const cellId = update.view.state.facet(cellIdState); + + if (sqlContent === "") { + clearSqlValidationError(cellId); + return; + } + + try { + const result: ValidateSQLResult = await ValidateSQL.request({ + engine: connectionName, + query: sqlContent, + }); + + if (result.error) { + const dialect = connectionNameToParserDialect(connectionName); + setSqlValidationError({ cellId, error: result.error, dialect }); + } else { + clearSqlValidationError(cellId); + } + } catch (error) { + Logger.warn("Failed to validate SQL", { error }); + } + }, 300); + }); +} diff --git a/frontend/src/core/codemirror/language/languages/sql/validation-errors.ts b/frontend/src/core/codemirror/language/languages/sql/validation-errors.ts new file mode 100644 index 00000000000..052eb23f6f1 --- /dev/null +++ b/frontend/src/core/codemirror/language/languages/sql/validation-errors.ts @@ -0,0 +1,79 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import type { SupportedDialects } from "@marimo-team/codemirror-sql"; +import { atom, useAtomValue } from "jotai"; +import type { CellId } from "@/core/cells/ids"; +import { store } from "@/core/state/jotai"; + +export interface SQLValidationError { + errorType: string; + errorMessage: string; + codeblock?: string; // Code block that caused the error +} + +type CellToSQLErrors = Map; + +export const sqlValidationErrorsAtom = atom( + new Map(), +); + +export const useSqlValidationErrorsForCell = (cellId: CellId) => { + const sqlValidationErrors = useAtomValue(sqlValidationErrorsAtom); + return sqlValidationErrors.get(cellId); +}; + +export function clearSqlValidationError(cellId: CellId) { + const sqlValidationErrors = store.get(sqlValidationErrorsAtom); + const newErrors = new Map(sqlValidationErrors); + newErrors.delete(cellId); + store.set(sqlValidationErrorsAtom, newErrors); +} + +export function setSqlValidationError({ + cellId, + error, + dialect, +}: { + cellId: CellId; + error: string; + dialect: SupportedDialects | null; +}) { + const sqlValidationErrors = store.get(sqlValidationErrorsAtom); + const newErrors = new Map(sqlValidationErrors); + + const errorResult: SQLValidationError = + dialect === "DuckDB" ? handleDuckdbError(error) : splitErrorMessage(error); + + newErrors.set(cellId, errorResult); + store.set(sqlValidationErrorsAtom, newErrors); +} + +function handleDuckdbError(error: string): SQLValidationError { + const { errorType, errorMessage } = splitErrorMessage(error); + let newErrorMessage = errorMessage; + + // Extract the LINE and the rest of the message as codeblock, keep errorMessage as whatever is before + let codeblock: string | undefined; + const lineIndex = errorMessage.indexOf("LINE "); + if (lineIndex !== -1) { + codeblock = errorMessage.slice(Math.max(0, lineIndex)).trim(); + newErrorMessage = errorMessage.slice(0, Math.max(0, lineIndex)).trim(); + } + + return { + errorType, + errorMessage: newErrorMessage, + codeblock, + }; +} + +function splitErrorMessage(error: string) { + const errorType = error.split(":")[0].trim(); + const errorMessage = error.split(":").slice(1).join(":").trim(); + return { errorType, errorMessage }; +} + +export const exportedForTesting = { + splitErrorMessage, + handleDuckdbError, +}; diff --git a/frontend/src/core/codemirror/language/panel/panel.tsx b/frontend/src/core/codemirror/language/panel/panel.tsx index a5404143a7a..c9147153210 100644 --- a/frontend/src/core/codemirror/language/panel/panel.tsx +++ b/frontend/src/core/codemirror/language/panel/panel.tsx @@ -5,7 +5,8 @@ import { Button } from "@/components/ui/button"; import { Checkbox } from "@/components/ui/checkbox"; import { Tooltip, TooltipProvider } from "@/components/ui/tooltip"; import { normalizeName } from "@/core/cells/names"; -import type { ConnectionName } from "@/core/datasets/engines"; +import { getFeatureFlag } from "@/core/config/feature-flag"; +import { type ConnectionName, DUCKDB_ENGINE } from "@/core/datasets/engines"; import { useAutoGrowInputProps } from "@/hooks/useAutoGrowInputProps"; import { formatSQL } from "../../format"; import { languageAdapterState } from "../extension"; @@ -22,7 +23,7 @@ import { import type { LanguageMetadataOf } from "../types"; import type { QuotePrefixKind } from "../utils/quotes"; import { getQuotePrefix, MarkdownQuotePrefixTooltip } from "./markdown"; -import { SQLEngineSelect } from "./sql"; +import { SQLEngineSelect, SQLModeSelect } from "./sql"; const Divider = () =>
; @@ -70,6 +71,8 @@ export const LanguagePanelComponent: React.FC<{ updateSQLDialectFromConnection(view, engine); }; + const sqlModeEnabled = getFeatureFlag("sql_mode"); + actions = (