diff --git a/frontend/src/__mocks__/requests.ts b/frontend/src/__mocks__/requests.ts
index aa9a50a2897..aea49e35329 100644
--- a/frontend/src/__mocks__/requests.ts
+++ b/frontend/src/__mocks__/requests.ts
@@ -40,6 +40,7 @@ export const MockRequestClient = {
previewSQLTable: vi.fn().mockResolvedValue({}),
previewSQLTableList: vi.fn().mockResolvedValue({ tables: [] }),
previewDataSourceConnection: vi.fn().mockResolvedValue({}),
+ validateSQL: vi.fn().mockResolvedValue({}),
openFile: vi.fn().mockResolvedValue({}),
getUsageStats: vi.fn().mockResolvedValue({}),
sendPdb: vi.fn().mockResolvedValue({}),
diff --git a/frontend/src/components/editor/Cell.tsx b/frontend/src/components/editor/Cell.tsx
index d0b3bbfcd5b..27b1ee579fd 100644
--- a/frontend/src/components/editor/Cell.tsx
+++ b/frontend/src/components/editor/Cell.tsx
@@ -80,6 +80,7 @@ import { useDeleteCellCallback } from "./cell/useDeleteCell";
import { useRunCell } from "./cell/useRunCells";
import { HideCodeButton } from "./code/readonly-python-code";
import { cellDomProps } from "./common";
+import { SqlValidationErrorBanner } from "./errors/sql-validation-errors";
import { useCellNavigationProps } from "./navigation/navigation";
import {
useTemporarilyShownCode,
@@ -653,6 +654,7 @@ const EditableCellComponent = ({
)}
+
{cellOutput === "below" && outputArea}
{cellRuntime.serialization && (
diff --git a/frontend/src/components/editor/errors/sql-validation-errors.tsx b/frontend/src/components/editor/errors/sql-validation-errors.tsx
new file mode 100644
index 00000000000..835be893215
--- /dev/null
+++ b/frontend/src/components/editor/errors/sql-validation-errors.tsx
@@ -0,0 +1,34 @@
+/* Copyright 2024 Marimo. All rights reserved. */
+
+import { AlertCircleIcon } from "lucide-react";
+import type { CellId } from "@/core/cells/ids";
+import { useSqlValidationErrorsForCell } from "@/core/codemirror/language/languages/sql/validation-errors";
+
+export const SqlValidationErrorBanner = ({ cellId }: { cellId: CellId }) => {
+ const error = useSqlValidationErrorsForCell(cellId);
+
+ if (!error) {
+ return;
+ }
+
+ return (
+
+
+
+
+ {error.errorType}:{" "}
+ {error.errorMessage}
+
+
+
+ {error.codeblock && (
+
+ {error.codeblock}
+
+ )}
+
+ );
+};
diff --git a/frontend/src/core/codemirror/language/__tests__/extension.test.ts b/frontend/src/core/codemirror/language/__tests__/extension.test.ts
index 4f0188916f1..e6d04c16729 100644
--- a/frontend/src/core/codemirror/language/__tests__/extension.test.ts
+++ b/frontend/src/core/codemirror/language/__tests__/extension.test.ts
@@ -14,6 +14,7 @@ import {
languageAdapterState,
switchLanguage,
} from "../extension";
+import { exportedForTesting as sqlValidationErrorsForTesting } from "../languages/sql/validation-errors";
import { languageMetadataField } from "../metadata";
let view: EditorView | null = null;
@@ -258,3 +259,26 @@ describe("switchLanguage", () => {
});
});
});
+
+describe("sqlValidationErrors", () => {
+ const { splitErrorMessage } = sqlValidationErrorsForTesting;
+
+ describe("split error message", () => {
+ it("should split the error message into error type and error message", () => {
+ const error = "SyntaxError: SELECT * FROM df";
+ const { errorType, errorMessage } = splitErrorMessage(error);
+ expect(errorType).toBe("SyntaxError");
+ expect(errorMessage).toBe("SELECT * FROM df");
+ });
+
+ it("should handle multiple colons", () => {
+ const error =
+ "SyntaxError: SELECT * FROM df:SyntaxError: SELECT * FROM df";
+ const { errorType, errorMessage } = splitErrorMessage(error);
+ expect(errorType).toBe("SyntaxError");
+ expect(errorMessage).toBe(
+ "SELECT * FROM df:SyntaxError: SELECT * FROM df",
+ );
+ });
+ });
+});
diff --git a/frontend/src/core/codemirror/language/__tests__/sql-validation.test.ts b/frontend/src/core/codemirror/language/__tests__/sql-validation.test.ts
new file mode 100644
index 00000000000..9edd9c1e738
--- /dev/null
+++ b/frontend/src/core/codemirror/language/__tests__/sql-validation.test.ts
@@ -0,0 +1,133 @@
+/* Copyright 2024 Marimo. All rights reserved. */
+
+import { describe, expect, it } from "vitest";
+import { exportedForTesting } from "../languages/sql/validation-errors";
+
+describe("Error Message Splitting", () => {
+ it("should handle error message splitting correctly", () => {
+ const { splitErrorMessage } = exportedForTesting;
+
+ const result1 = splitErrorMessage("Syntax error: unexpected token");
+ expect(result1.errorType).toBe("Syntax error");
+ expect(result1.errorMessage).toBe("unexpected token");
+
+ const result2 = splitErrorMessage("Multiple: colons: in error");
+ expect(result2.errorType).toBe("Multiple");
+ expect(result2.errorMessage).toBe("colons: in error");
+
+ const result3 = splitErrorMessage("No colon error");
+ expect(result3.errorType).toBe("No colon error");
+ expect(result3.errorMessage).toBe("");
+ });
+});
+
+describe("DuckDB Error Handling", () => {
+ it("should extract codeblock from error with LINE information", () => {
+ const { handleDuckdbError } = exportedForTesting;
+
+ const error =
+ 'Binder Error: Referenced column "attacks" not found in FROM clause! Candidate bindings: "Attack", "Total" LINE 1:... from pokemon WHERE \'type_2\' = 32 and attack = 32 and not attacks = \'hi\' ^';
+
+ const result = handleDuckdbError(error);
+
+ expect(result.errorType).toBe("Binder Error");
+ expect(result.errorMessage).toBe(
+ 'Referenced column "attacks" not found in FROM clause! Candidate bindings: "Attack", "Total"',
+ );
+ expect(result.codeblock).toBe(
+ "LINE 1:... from pokemon WHERE 'type_2' = 32 and attack = 32 and not attacks = 'hi' ^",
+ );
+ });
+
+ it("should handle error without LINE information", () => {
+ const { handleDuckdbError } = exportedForTesting;
+
+ const error = "Syntax Error: Invalid syntax near WHERE";
+
+ const result = handleDuckdbError(error);
+
+ expect(result.errorType).toBe("Syntax Error");
+ expect(result.errorMessage).toBe("Invalid syntax near WHERE");
+ expect(result.codeblock).toBeUndefined();
+ });
+
+ it("should handle error with LINE at the beginning", () => {
+ const { handleDuckdbError } = exportedForTesting;
+
+ const error = "LINE 1: SELECT * FROM table WHERE invalid_column = 1 ^";
+
+ const result = handleDuckdbError(error);
+
+ expect(result.errorType).toBe("LINE 1");
+ expect(result.errorMessage).toBe(
+ "SELECT * FROM table WHERE invalid_column = 1 ^",
+ );
+ expect(result.codeblock).toBeUndefined();
+ });
+
+ it("should handle error with multiple LINE occurrences", () => {
+ const { handleDuckdbError } = exportedForTesting;
+
+ const error =
+ "Error: Something went wrong LINE 1: SELECT * FROM table WHERE invalid_column = 1 ^";
+
+ const result = handleDuckdbError(error);
+
+ expect(result.errorType).toBe("Error");
+ expect(result.errorMessage).toBe("Something went wrong");
+ expect(result.codeblock).toBe(
+ "LINE 1: SELECT * FROM table WHERE invalid_column = 1 ^",
+ );
+ });
+
+ it("should handle complex error with nested quotes", () => {
+ const { handleDuckdbError } = exportedForTesting;
+
+ const error =
+ "Binder Error: Column \"name\" not found! LINE 1: SELECT * FROM users WHERE name = 'John' AND age > 25 ^";
+
+ const result = handleDuckdbError(error);
+
+ expect(result.errorType).toBe("Binder Error");
+ expect(result.errorMessage).toBe('Column "name" not found!');
+ expect(result.codeblock).toBe(
+ "LINE 1: SELECT * FROM users WHERE name = 'John' AND age > 25 ^",
+ );
+ });
+
+ it("should handle error with LINE but no caret", () => {
+ const { handleDuckdbError } = exportedForTesting;
+
+ const error = "Error: Invalid query LINE 1: SELECT * FROM table";
+
+ const result = handleDuckdbError(error);
+
+ expect(result.errorType).toBe("Error");
+ expect(result.errorMessage).toBe("Invalid query");
+ expect(result.codeblock).toBe("LINE 1: SELECT * FROM table");
+ });
+
+ it("should trim whitespace from codeblock", () => {
+ const { handleDuckdbError } = exportedForTesting;
+
+ const error = "Error: Something wrong LINE 1: SELECT * FROM table ^ ";
+
+ const result = handleDuckdbError(error);
+
+ expect(result.errorType).toBe("Error");
+ expect(result.errorMessage).toBe("Something wrong");
+ expect(result.codeblock).toBe("LINE 1: SELECT * FROM table ^");
+ });
+
+ it("should handle empty error message", () => {
+ const { handleDuckdbError } = exportedForTesting;
+
+ const error = "";
+
+ const result = handleDuckdbError(error);
+
+ expect(result.errorType).toBe("");
+ expect(result.errorMessage).toBe("");
+ expect(result.codeblock).toBeUndefined();
+ });
+});
diff --git a/frontend/src/core/codemirror/language/languages/sql/sql-mode.ts b/frontend/src/core/codemirror/language/languages/sql/sql-mode.ts
new file mode 100644
index 00000000000..31587629393
--- /dev/null
+++ b/frontend/src/core/codemirror/language/languages/sql/sql-mode.ts
@@ -0,0 +1,20 @@
+/* Copyright 2024 Marimo. All rights reserved. */
+
+import { useAtom } from "jotai";
+import { atomWithStorage } from "jotai/utils";
+import { store } from "@/core/state/jotai";
+
+const BASE_KEY = "marimo:notebook-sql-mode";
+
+export type SQLMode = "validate" | "default";
+
+const sqlModeAtom = atomWithStorage
(BASE_KEY, "default");
+
+export function useSQLMode() {
+ const [sqlMode, setSQLMode] = useAtom(sqlModeAtom);
+ return { sqlMode, setSQLMode };
+}
+
+export function getSQLMode() {
+ return store.get(sqlModeAtom);
+}
diff --git a/frontend/src/core/codemirror/language/languages/sql/sql.ts b/frontend/src/core/codemirror/language/languages/sql/sql.ts
index bc6c6e0fb82..3ef15ed407e 100644
--- a/frontend/src/core/codemirror/language/languages/sql/sql.ts
+++ b/frontend/src/core/codemirror/language/languages/sql/sql.ts
@@ -5,7 +5,7 @@ import { insertTab } from "@codemirror/commands";
import { type SQLDialect, type SQLNamespace, sql } from "@codemirror/lang-sql";
import type { EditorState, Extension } from "@codemirror/state";
import { Compartment } from "@codemirror/state";
-import { type EditorView, keymap } from "@codemirror/view";
+import { EditorView, keymap } from "@codemirror/view";
import type { SyntaxNode, TreeCursor } from "@lezer/common";
import { parser } from "@lezer/python";
import {
@@ -16,12 +16,19 @@ import {
} from "@marimo-team/codemirror-sql";
import { DuckDBDialect } from "@marimo-team/codemirror-sql/dialects";
import dedent from "string-dedent";
+import { cellIdState } from "@/core/codemirror/cells/state";
import { getFeatureFlag } from "@/core/config/feature-flag";
import {
dataSourceConnectionsAtom,
setLatestEngineSelected,
} from "@/core/datasets/data-source-connections";
-import { type ConnectionName, DUCKDB_ENGINE } from "@/core/datasets/engines";
+import {
+ type ConnectionName,
+ DUCKDB_ENGINE,
+ INTERNAL_SQL_ENGINES,
+} from "@/core/datasets/engines";
+import { ValidateSQL } from "@/core/datasets/request-registry";
+import type { ValidateSQLResult } from "@/core/kernel/messages";
import { store } from "@/core/state/jotai";
import { resolvedThemeAtom } from "@/theme/useTheme";
import { Logger } from "@/utils/Logger";
@@ -37,6 +44,11 @@ import {
tablesCompletionSource,
} from "./completion-sources";
import { SCHEMA_CACHE } from "./completion-store";
+import { getSQLMode } from "./sql-mode";
+import {
+ clearSqlValidationError,
+ setSqlValidationError,
+} from "./validation-errors";
const DEFAULT_DIALECT = DuckDBDialect;
const DEFAULT_PARSER_DIALECT = "DuckDB";
@@ -64,12 +76,15 @@ export class SQLLanguageAdapter
{
readonly type = "sql";
sqlLinterEnabled: boolean;
+ sqlModeEnabled: boolean;
constructor() {
try {
this.sqlLinterEnabled = getFeatureFlag("sql_linter");
+ this.sqlModeEnabled = getFeatureFlag("sql_mode");
} catch {
this.sqlLinterEnabled = false;
+ this.sqlModeEnabled = false;
}
}
@@ -265,6 +280,10 @@ export class SQLLanguageAdapter
);
}
+ if (this.sqlModeEnabled) {
+ extensions.push(sqlValidationExtension());
+ }
+
return extensions;
}
}
@@ -315,9 +334,14 @@ function getSchema(view: EditorView): SQLNamespace {
function guessParserDialect(state: EditorState): ParserDialects | null {
const metadata = getSQLMetadata(state);
const connectionName = metadata.engine;
+ return connectionNameToParserDialect(connectionName);
+}
+
+function connectionNameToParserDialect(
+ connectionName: ConnectionName,
+): ParserDialects | null {
const dialect =
SCHEMA_CACHE.getInternalDialect(connectionName)?.toLowerCase();
-
switch (dialect) {
case "postgresql":
case "postgres":
@@ -543,3 +567,66 @@ function safeDedent(code: string): string {
return code;
}
}
+
+function sqlValidationExtension(): Extension {
+ let debounceTimeout: NodeJS.Timeout | null = null;
+ let lastValidationRequest: string | null = null;
+
+ return EditorView.updateListener.of((update) => {
+ const sqlMode = getSQLMode();
+ if (sqlMode !== "validate") {
+ return;
+ }
+
+ const metadata = getSQLMetadata(update.state);
+ const connectionName = metadata.engine;
+ if (!INTERNAL_SQL_ENGINES.has(connectionName)) {
+ // Currently only internal engines are supported
+ return;
+ }
+
+ if (!update.docChanged) {
+ return;
+ }
+
+ const doc = update.state.doc;
+ const sqlContent = doc.toString();
+
+ // Clear existing timeout
+ if (debounceTimeout) {
+ clearTimeout(debounceTimeout);
+ }
+
+ // Debounce the validation call
+ debounceTimeout = setTimeout(async () => {
+ // Skip if content hasn't changed since last validation
+ if (lastValidationRequest === sqlContent) {
+ return;
+ }
+
+ lastValidationRequest = sqlContent;
+ const cellId = update.view.state.facet(cellIdState);
+
+ if (sqlContent === "") {
+ clearSqlValidationError(cellId);
+ return;
+ }
+
+ try {
+ const result: ValidateSQLResult = await ValidateSQL.request({
+ engine: connectionName,
+ query: sqlContent,
+ });
+
+ if (result.error) {
+ const dialect = connectionNameToParserDialect(connectionName);
+ setSqlValidationError({ cellId, error: result.error, dialect });
+ } else {
+ clearSqlValidationError(cellId);
+ }
+ } catch (error) {
+ Logger.warn("Failed to validate SQL", { error });
+ }
+ }, 300);
+ });
+}
diff --git a/frontend/src/core/codemirror/language/languages/sql/validation-errors.ts b/frontend/src/core/codemirror/language/languages/sql/validation-errors.ts
new file mode 100644
index 00000000000..052eb23f6f1
--- /dev/null
+++ b/frontend/src/core/codemirror/language/languages/sql/validation-errors.ts
@@ -0,0 +1,79 @@
+/* Copyright 2024 Marimo. All rights reserved. */
+
+import type { SupportedDialects } from "@marimo-team/codemirror-sql";
+import { atom, useAtomValue } from "jotai";
+import type { CellId } from "@/core/cells/ids";
+import { store } from "@/core/state/jotai";
+
+export interface SQLValidationError {
+ errorType: string;
+ errorMessage: string;
+ codeblock?: string; // Code block that caused the error
+}
+
+type CellToSQLErrors = Map;
+
+export const sqlValidationErrorsAtom = atom(
+ new Map(),
+);
+
+export const useSqlValidationErrorsForCell = (cellId: CellId) => {
+ const sqlValidationErrors = useAtomValue(sqlValidationErrorsAtom);
+ return sqlValidationErrors.get(cellId);
+};
+
+export function clearSqlValidationError(cellId: CellId) {
+ const sqlValidationErrors = store.get(sqlValidationErrorsAtom);
+ const newErrors = new Map(sqlValidationErrors);
+ newErrors.delete(cellId);
+ store.set(sqlValidationErrorsAtom, newErrors);
+}
+
+export function setSqlValidationError({
+ cellId,
+ error,
+ dialect,
+}: {
+ cellId: CellId;
+ error: string;
+ dialect: SupportedDialects | null;
+}) {
+ const sqlValidationErrors = store.get(sqlValidationErrorsAtom);
+ const newErrors = new Map(sqlValidationErrors);
+
+ const errorResult: SQLValidationError =
+ dialect === "DuckDB" ? handleDuckdbError(error) : splitErrorMessage(error);
+
+ newErrors.set(cellId, errorResult);
+ store.set(sqlValidationErrorsAtom, newErrors);
+}
+
+function handleDuckdbError(error: string): SQLValidationError {
+ const { errorType, errorMessage } = splitErrorMessage(error);
+ let newErrorMessage = errorMessage;
+
+ // Extract the LINE and the rest of the message as codeblock, keep errorMessage as whatever is before
+ let codeblock: string | undefined;
+ const lineIndex = errorMessage.indexOf("LINE ");
+ if (lineIndex !== -1) {
+ codeblock = errorMessage.slice(Math.max(0, lineIndex)).trim();
+ newErrorMessage = errorMessage.slice(0, Math.max(0, lineIndex)).trim();
+ }
+
+ return {
+ errorType,
+ errorMessage: newErrorMessage,
+ codeblock,
+ };
+}
+
+function splitErrorMessage(error: string) {
+ const errorType = error.split(":")[0].trim();
+ const errorMessage = error.split(":").slice(1).join(":").trim();
+ return { errorType, errorMessage };
+}
+
+export const exportedForTesting = {
+ splitErrorMessage,
+ handleDuckdbError,
+};
diff --git a/frontend/src/core/codemirror/language/panel/panel.tsx b/frontend/src/core/codemirror/language/panel/panel.tsx
index a5404143a7a..c9147153210 100644
--- a/frontend/src/core/codemirror/language/panel/panel.tsx
+++ b/frontend/src/core/codemirror/language/panel/panel.tsx
@@ -5,7 +5,8 @@ import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { Tooltip, TooltipProvider } from "@/components/ui/tooltip";
import { normalizeName } from "@/core/cells/names";
-import type { ConnectionName } from "@/core/datasets/engines";
+import { getFeatureFlag } from "@/core/config/feature-flag";
+import { type ConnectionName, DUCKDB_ENGINE } from "@/core/datasets/engines";
import { useAutoGrowInputProps } from "@/hooks/useAutoGrowInputProps";
import { formatSQL } from "../../format";
import { languageAdapterState } from "../extension";
@@ -22,7 +23,7 @@ import {
import type { LanguageMetadataOf } from "../types";
import type { QuotePrefixKind } from "../utils/quotes";
import { getQuotePrefix, MarkdownQuotePrefixTooltip } from "./markdown";
-import { SQLEngineSelect } from "./sql";
+import { SQLEngineSelect, SQLModeSelect } from "./sql";
const Divider = () => ;
@@ -70,6 +71,8 @@ export const LanguagePanelComponent: React.FC<{
updateSQLDialectFromConnection(view, engine);
};
+ const sqlModeEnabled = getFeatureFlag("sql_mode");
+
actions = (