Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"@marimo-team/codemirror-ai": "^0.3.2",
"@marimo-team/codemirror-languageserver": "1.15.24",
"@marimo-team/codemirror-mcp": "^0.1.5",
"@marimo-team/codemirror-sql": "^0.2.1",
"@marimo-team/codemirror-sql": "^0.2.3",
"@marimo-team/llm-info": "workspace:*",
"@marimo-team/marimo-api": "workspace:*",
"@marimo-team/react-slotz": "^0.2.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import { AlertCircleIcon } from "lucide-react";
import type { CellId } from "@/core/cells/ids";
import { useSqlValidationErrorsForCell } from "@/core/codemirror/language/languages/sql/validation-errors";
import { useSqlValidationErrorsForCell } from "@/core/codemirror/language/languages/sql/banner-validation-errors";

export const SqlValidationErrorBanner = ({
cellId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
languageAdapterState,
switchLanguage,
} from "../extension";
import { exportedForTesting as sqlValidationErrorsForTesting } from "../languages/sql/validation-errors";
import { exportedForTesting as sqlValidationErrorsForTesting } from "../languages/sql/banner-validation-errors";
import { languageMetadataField } from "../metadata";

let view: EditorView | null = null;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* Copyright 2024 Marimo. All rights reserved. */

import { describe, expect, it } from "vitest";
import { exportedForTesting } from "../languages/sql/validation-errors";
import { exportedForTesting } from "../languages/sql/banner-validation-errors";

describe("Error Message Splitting", () => {
it("should handle error message splitting correctly", () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,20 @@ export function clearSqlValidationError(cellId: CellId) {

export function setSqlValidationError({
cellId,
error,
errorMessage,
dialect,
}: {
cellId: CellId;
error: string;
errorMessage: string;
dialect: SupportedDialects | null;
}) {
const sqlValidationErrors = store.get(sqlValidationErrorsAtom);
const newErrors = new Map(sqlValidationErrors);

const errorResult: SQLValidationError =
dialect === "DuckDB" ? handleDuckdbError(error) : splitErrorMessage(error);
dialect === "DuckDB"
? handleDuckdbError(errorMessage)
: splitErrorMessage(errorMessage);

newErrors.set(cellId, errorResult);
store.set(sqlValidationErrorsAtom, newErrors);
Expand Down
128 changes: 107 additions & 21 deletions frontend/src/core/codemirror/language/languages/sql/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import { parser } from "@lezer/python";
import {
defaultSqlHoverTheme,
NodeSqlParser,
type NodeSqlParserResult,
type SupportedDialects as ParserDialects,
type SqlParseError,
sqlExtension,
} from "@marimo-team/codemirror-sql";
import { DuckDBDialect } from "@marimo-team/codemirror-sql/dialects";
Expand All @@ -28,7 +30,6 @@ import {
INTERNAL_SQL_ENGINES,
} from "@/core/datasets/engines";
import { ValidateSQL } from "@/core/datasets/request-registry";
import type { ValidateSQLResult } from "@/core/kernel/messages";
import { store } from "@/core/state/jotai";
import { resolvedThemeAtom } from "@/theme/useTheme";
import { Logger } from "@/utils/Logger";
Expand All @@ -39,16 +40,16 @@ import { parseArgsKwargs } from "../../utils/ast";
import { indentOneTab } from "../../utils/indentOneTab";
import type { QuotePrefixKind } from "../../utils/quotes";
import { MarkdownLanguageAdapter } from "../markdown";
import {
clearSqlValidationError,
setSqlValidationError,
} from "./banner-validation-errors";
import {
customKeywordCompletionSource,
tablesCompletionSource,
} from "./completion-sources";
import { SCHEMA_CACHE } from "./completion-store";
import { getSQLMode } from "./sql-mode";
import {
clearSqlValidationError,
setSqlValidationError,
} from "./validation-errors";

const DEFAULT_DIALECT = DuckDBDialect;
const DEFAULT_PARSER_DIALECT = "DuckDB";
Expand Down Expand Up @@ -245,7 +246,7 @@ export class SQLLanguageAdapter

if (this.sqlLinterEnabled) {
const theme = store.get(resolvedThemeAtom);
const parser = new NodeSqlParser({
const parser = new CustomSqlParser({
getParserOptions: (state: EditorState) => {
return {
database: guessParserDialect(state) ?? DEFAULT_PARSER_DIALECT,
Expand Down Expand Up @@ -280,14 +281,84 @@ export class SQLLanguageAdapter
);
}

if (this.sqlModeEnabled) {
extensions.push(sqlValidationExtension());
}
// TODO: Re-enable after we optimize the endpoint
// if (this.sqlModeEnabled) {
// extensions.push(sqlValidationExtension());
// }

return extensions;
}
}

class CustomSqlParser extends NodeSqlParser {
private validationTimeout: number | null = null;
private readonly VALIDATION_DELAY_MS = 300; // Wait 300ms after user stops typing

private async validateWithDelay(
sql: string,
engine: string,
dialect: ParserDialects | null,
): Promise<SqlParseError[]> {
// Clear any existing delay call
if (this.validationTimeout) {
window.clearTimeout(this.validationTimeout);
}

// Set up a new request to be called after the delay
return new Promise((resolve) => {
this.validationTimeout = window.setTimeout(async () => {
try {
const result = await ValidateSQL.request({
onlyParse: true,
engine,
dialect,
query: sql,
});
if (result.error) {
Logger.error("Failed to validate SQL", { error: result.error });
resolve([]);
return;
}
resolve(result.parse_result?.errors ?? []);
} catch (error) {
Logger.error("Failed to validate SQL", { error });
resolve([]);
}
}, this.VALIDATION_DELAY_MS);
});
}

override async validateSql(
sql: string,
opts: { state: EditorState },
): Promise<SqlParseError[]> {
const metadata = getSQLMetadata(opts.state);

// Only perform custom validation for internal engines
if (!INTERNAL_SQL_ENGINES.has(metadata.engine)) {
return super.validateSql(sql, opts);
}

const dialect = guessParserDialect(opts.state);
return this.validateWithDelay(sql, metadata.engine, dialect);
}

override async parse(
sql: string,
opts: { state: EditorState },
): Promise<NodeSqlParserResult> {
const metadata = getSQLMetadata(opts.state);
const engine = metadata.engine;

// For now, always return success for DuckDB
if (engine === DUCKDB_ENGINE) {
return { success: true, errors: [] };
}

return super.parse(sql, opts);
}
}

/**
* Update the SQL dialect in the editor view.
*/
Expand Down Expand Up @@ -568,24 +639,27 @@ function safeDedent(code: string): string {
}
}

function sqlValidationExtension(): Extension {
let debounceTimeout: NodeJS.Timeout | null = null;
const SQL_VALIDATION_DEBOUNCE_MS = 300;

// @ts-expect-error: TODO: Re-enable after we optimize the endpoint
function _sqlValidationExtension(): Extension {
let debounceTimeout: number | undefined;
let lastValidationRequest: string | null = null;

return EditorView.updateListener.of((update) => {
if (!update.docChanged) {
return;
}

const sqlMode = getSQLMode();
if (sqlMode !== "validate") {
return;
}

const metadata = getSQLMetadata(update.state);
const connectionName = metadata.engine;
// Currently only internal engines are supported
if (!INTERNAL_SQL_ENGINES.has(connectionName)) {
// Currently only internal engines are supported
return;
}

if (!update.docChanged) {
return;
}

Expand All @@ -594,11 +668,11 @@ function sqlValidationExtension(): Extension {

// Clear existing timeout
if (debounceTimeout) {
clearTimeout(debounceTimeout);
window.clearTimeout(debounceTimeout);
}

// Debounce the validation call
debounceTimeout = setTimeout(async () => {
debounceTimeout = window.setTimeout(async () => {
// Skip if content hasn't changed since last validation
if (lastValidationRequest === sqlContent) {
return;
Expand All @@ -613,20 +687,32 @@ function sqlValidationExtension(): Extension {
}

try {
const result: ValidateSQLResult = await ValidateSQL.request({
const result = await ValidateSQL.request({
onlyParse: false,
engine: connectionName,
query: sqlContent,
});

if (result.error) {
Logger.error("Failed to validate SQL", { error: result.error });
return;
}

const validateResult = result.validate_result;

if (validateResult?.error_message) {
const dialect = connectionNameToParserDialect(connectionName);
setSqlValidationError({ cellId, error: result.error, dialect });
setSqlValidationError({
cellId,
errorMessage: validateResult.error_message,
dialect,
});
} else {
clearSqlValidationError(cellId);
}
} catch (error) {
Logger.warn("Failed to validate SQL", { error });
}
}, 300);
}, SQL_VALIDATION_DEBOUNCE_MS);
});
}
2 changes: 1 addition & 1 deletion marimo/_config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ class ExperimentalConfig(TypedDict, total=False):
performant_table_charts: bool
mcp_docs: bool
sql_linter: bool
sql_mode: bool
sql_mode: bool # Not exposed for now

# Internal features
cache: CacheConfig
Expand Down
4 changes: 3 additions & 1 deletion marimo/_messaging/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from marimo._runtime.context.utils import get_mode
from marimo._runtime.layout.layout import LayoutConfig
from marimo._secrets.models import SecretKeysWithProvider
from marimo._sql.parse import SqlCatalogCheckResult, SqlParseResult
from marimo._types.ids import CellId_t, RequestId, WidgetModelId
from marimo._utils.platform import is_pyodide, is_windows

Expand Down Expand Up @@ -629,7 +630,8 @@ class DataSourceConnections(Op, tag="data-source-connections"):
class ValidateSQLResult(Op, tag="validate-sql-result"):
name: ClassVar[str] = "validate-sql-result"
request_id: RequestId
result: Optional[Any] = None
parse_result: Optional[SqlParseResult] = None
validate_result: Optional[SqlCatalogCheckResult] = None
error: Optional[str] = None


Expand Down
16 changes: 14 additions & 2 deletions marimo/_runtime/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,22 @@ class PreviewDataSourceConnectionRequest(msgspec.Struct, rename="camel"):


class ValidateSQLRequest(msgspec.Struct, rename="camel"):
"""Validate an SQL query"""
"""Validate an SQL query against the engine"""

request_id: RequestId
engine: str
query: str
# Whether to only parse the query or validate against the database
# Parsing is done without a DB connection and uses dialect, whereas validation requires a connection
only_parse: bool
engine: Optional[str] = None
dialect: Optional[str] = None


class ParseSQLRequest(msgspec.Struct, rename="camel"):
"""Parse an SQL query for linting"""

request_id: RequestId
dialect: str
query: str


Expand Down
Loading
Loading