Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ export function clearSqlValidationError(cellId: CellId) {
store.set(sqlValidationErrorsAtom, newErrors);
}

export function clearAllSqlValidationErrors() {
store.set(sqlValidationErrorsAtom, new Map<CellId, SQLValidationError>());
}

export function setSqlValidationError({
cellId,
errorMessage,
Expand Down
70 changes: 42 additions & 28 deletions frontend/src/core/codemirror/language/languages/sql/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
INTERNAL_SQL_ENGINES,
} from "@/core/datasets/engines";
import { ValidateSQL } from "@/core/datasets/request-registry";
import type { ValidateSQLResult } from "@/core/kernel/messages";
import { store } from "@/core/state/jotai";
import { resolvedThemeAtom } from "@/theme/useTheme";
import { Logger } from "@/utils/Logger";
Expand All @@ -49,7 +50,7 @@ import {
tablesCompletionSource,
} from "./completion-sources";
import { SCHEMA_CACHE } from "./completion-store";
import { getSQLMode } from "./sql-mode";
import { getSQLMode, type SQLMode } from "./sql-mode";

const DEFAULT_DIALECT = DuckDBDialect;
const DEFAULT_PARSER_DIALECT = "DuckDB";
Expand Down Expand Up @@ -307,12 +308,8 @@ class CustomSqlParser extends NodeSqlParser {
return new Promise((resolve) => {
this.validationTimeout = window.setTimeout(async () => {
try {
const result = await ValidateSQL.request({
onlyParse: true,
engine,
dialect,
query: sql,
});
const sqlMode = getSQLMode();
const result = await validateSQL(sql, engine, dialect, sqlMode);
if (result.error) {
Logger.error("Failed to validate SQL", { error: result.error });
resolve([]);
Expand All @@ -333,7 +330,7 @@ class CustomSqlParser extends NodeSqlParser {
): Promise<SqlParseError[]> {
const metadata = getSQLMetadata(opts.state);

// Only perform custom validation for internal engines
// Only perform custom validation for DuckDB
if (!INTERNAL_SQL_ENGINES.has(metadata.engine)) {
return super.validateSql(sql, opts);
}
Expand Down Expand Up @@ -640,6 +637,9 @@ function safeDedent(code: string): string {

const SQL_VALIDATION_DEBOUNCE_MS = 300;

/**
* Custom extension to run SQL queries in EXPLAIN mode on keypress.
*/
function sqlValidationExtension(): Extension {
let debounceTimeout: number | undefined;
let lastValidationRequest: string | null = null;
Expand All @@ -649,14 +649,10 @@ function sqlValidationExtension(): Extension {
return;
}

const sqlMode = getSQLMode();
if (sqlMode !== "validate") {
return;
}

const metadata = getSQLMetadata(update.state);
const connectionName = metadata.engine;
// Currently only internal engines are supported

// Currently only DuckDB is supported
if (!INTERNAL_SQL_ENGINES.has(connectionName)) {
return;
}
Expand Down Expand Up @@ -685,22 +681,17 @@ function sqlValidationExtension(): Extension {
}

try {
const result = await ValidateSQL.request({
onlyParse: false,
dialect: connectionNameToParserDialect(connectionName),
engine: connectionName,
query: sqlContent,
});

if (result.error) {
Logger.error("Failed to validate SQL", { error: result.error });
return;
}

const dialect = connectionNameToParserDialect(connectionName);
const sqlMode = getSQLMode();
const result = await validateSQL(
sqlContent,
connectionName,
dialect,
sqlMode,
);
const validateResult = result.validate_result;

if (validateResult?.error_message) {
const dialect = connectionNameToParserDialect(connectionName);
setSqlValidationError({
cellId,
errorMessage: validateResult.error_message,
Expand All @@ -710,8 +701,31 @@ function sqlValidationExtension(): Extension {
clearSqlValidationError(cellId);
}
} catch (error) {
Logger.warn("Failed to validate SQL", { error });
Logger.error("Failed to validate SQL", { error });
}
}, SQL_VALIDATION_DEBOUNCE_MS);
});
}

/**
* Determine if we should only parse or validate an SQL query.
* The endpoint is cached, so we should use the same mode for all validation requests.
*/
async function validateSQL(
sql: string,
engine: string,
dialect: ParserDialects | null,
sqlMode: SQLMode,
): Promise<ValidateSQLResult> {
const result = await ValidateSQL.request({
onlyParse: sqlMode === "default",
engine,
dialect,
query: sql,
});

if (result.error) {
throw new Error(result.error);
}
return result;
}
7 changes: 6 additions & 1 deletion frontend/src/core/codemirror/language/panel/sql.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
} from "@/core/datasets/engines";
import type { DataSourceConnection } from "@/core/kernel/messages";
import { useNonce } from "@/hooks/useNonce";
import { clearAllSqlValidationErrors } from "../languages/sql/banner-validation-errors";
import { type SQLMode, useSQLMode } from "../languages/sql/sql-mode";

interface SelectProps {
Expand Down Expand Up @@ -144,7 +145,11 @@ export const SQLModeSelect: React.FC = () => {
const { sqlMode, setSQLMode } = useSQLMode();

const handleToggleMode = () => {
setSQLMode(sqlMode === "validate" ? "default" : "validate");
const nextMode = sqlMode === "validate" ? "default" : "validate";
if (nextMode === "default") {
clearAllSqlValidationErrors();
}
setSQLMode(nextMode);
};

const getModeIcon = (mode: SQLMode) => {
Expand Down
16 changes: 13 additions & 3 deletions marimo/_runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2491,8 +2491,8 @@ def __init__(self, kernel: Kernel):
async def _validate_sql_query(self, request: ValidateSQLRequest) -> None:
"""Validate an SQL query

This will either validate:
- the syntax (parsing) or
This will validate:
- the syntax (parsing)
- the catalog (table and column names)
"""
request_id = request.request_id
Expand Down Expand Up @@ -2541,10 +2541,17 @@ async def _validate_sql_query(self, request: ValidateSQLRequest) -> None:
).broadcast()
return

# Get the parse error for linting
parse_result, parse_error = parse_sql(request.query, engine.dialect)
if parse_error is not None:
# We don't want to fail the validation if there is a parse error
LOGGER.debug("Parse error: %s", parse_error)

if not isinstance(engine, QueryEngine):
ValidateSQLResult(
request_id=request_id,
error=f"Engine {variable_name} does not support catalog validation.",
parse_result=parse_result,
).broadcast()
return

Expand All @@ -2554,7 +2561,10 @@ async def _validate_sql_query(self, request: ValidateSQLRequest) -> None:
error_message=error_message,
)
ValidateSQLResult(
request_id=request_id, validate_result=validate_result, error=None
request_id=request_id,
validate_result=validate_result,
parse_result=parse_result,
error=None,
).broadcast()

@kernel_tracer.start_as_current_span("validate_sql")
Expand Down
9 changes: 4 additions & 5 deletions tests/_runtime/test_runtime_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def test_internal_engine_and_valid_query(
]
assert validate_sql_results[-1] == ValidateSQLResult(
request_id=RequestId("1"),
parse_result=None,
parse_result=SqlParseResult(success=True, errors=[]),
validate_result=SqlCatalogCheckResult(
success=True, error_message=None
),
Expand All @@ -365,11 +365,10 @@ async def test_internal_engine_and_invalid_query(
latest_validate_sql_result = validate_sql_results[-1]
assert latest_validate_sql_result.request_id == RequestId("2")

# if not only_parse, parse_result is not None
assert latest_validate_sql_result.parse_result is None
assert latest_validate_sql_result.parse_result is not None
# query is syntactically valid
# assert latest_validate_sql_result.parse_result.success is True
# assert len(latest_validate_sql_result.parse_result.errors) == 0
assert latest_validate_sql_result.parse_result.success is True
assert len(latest_validate_sql_result.parse_result.errors) == 0

assert latest_validate_sql_result.validate_result is not None
assert latest_validate_sql_result.validate_result.success is False
Expand Down
Loading