Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
457 changes: 319 additions & 138 deletions frontend/src/components/datasources/__tests__/utils.test.ts

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions frontend/src/components/datasources/datasources.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ export const DataSources: React.FC = () => {
databaseName={database.name}
hasSearch={hasSearch}
searchValue={searchValue}
dialect={connection.dialect}
/>
</DatabaseItem>
))}
Expand Down Expand Up @@ -340,6 +341,7 @@ const SchemaList: React.FC<{
schemas: DatabaseSchema[];
defaultSchema?: string | null;
defaultDatabase?: string | null;
dialect: string;
engineName: string;
databaseName: string;
hasSearch: boolean;
Expand All @@ -348,6 +350,7 @@ const SchemaList: React.FC<{
schemas,
defaultSchema,
defaultDatabase,
dialect,
engineName,
databaseName,
hasSearch,
Expand Down Expand Up @@ -384,6 +387,7 @@ const SchemaList: React.FC<{
schema: schema.name,
defaultSchema: defaultSchema,
defaultDatabase: defaultDatabase,
dialect: dialect,
}}
/>
</SchemaItem>
Expand Down
77 changes: 73 additions & 4 deletions frontend/src/components/datasources/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
/* Copyright 2024 Marimo. All rights reserved. */

import { BigQueryDialect } from "@marimo-team/codemirror-sql/dialects";
import { isKnownDialect } from "@/core/codemirror/language/languages/sql/utils";
import type { SQLTableContext } from "@/core/datasets/data-source-connections";
import { DUCKDB_ENGINE } from "@/core/datasets/engines";
import type { DataTable, DataType } from "@/core/kernel/messages";
Expand All @@ -9,6 +12,59 @@ export function isSchemaless(schemaName: string) {
return schemaName === "";
}

interface SqlCodeFormatter {
/**
* Format the table name based on dialect-specific rules
*/
formatTableName: (tableName: string) => string;
/**
* Format the SELECT clause
*/
formatSelectClause: (columnName: string, tableName: string) => string;
}

const defaultFormatter: SqlCodeFormatter = {
formatTableName: (tableName: string) => tableName,
formatSelectClause: (columnName: string, tableName: string) =>
`SELECT ${columnName} FROM ${tableName} LIMIT 100`,
};

function getFormatter(dialect: string): SqlCodeFormatter {
dialect = dialect.toLowerCase();
if (!isKnownDialect(dialect)) {
return defaultFormatter;
}

switch (dialect) {
case "bigquery": {
const quote = BigQueryDialect.spec.identifierQuotes;
return {
// BigQuery uses backticks for identifiers
formatTableName: (tableName: string) => `${quote}${tableName}${quote}`,
formatSelectClause: defaultFormatter.formatSelectClause,
};
}
case "mssql":
case "sqlserver":
return {
formatTableName: defaultFormatter.formatTableName,
formatSelectClause: (columnName: string, tableName: string) =>
`SELECT TOP 100 ${columnName} FROM ${tableName}`,
};
case "timescaledb":
return {
// TimescaleDB uses double quotes for identifiers
formatTableName: (tableName: string) => {
const parts = tableName.split(".");
return parts.map((part) => `"${part}"`).join(".");
},
formatSelectClause: defaultFormatter.formatSelectClause,
};
default:
return defaultFormatter;
}
}

export function sqlCode({
table,
columnName,
Expand All @@ -19,8 +75,14 @@ export function sqlCode({
sqlTableContext?: SQLTableContext;
}) {
if (sqlTableContext) {
const { engine, schema, defaultSchema, defaultDatabase, database } =
sqlTableContext;
const {
engine,
schema,
defaultSchema,
defaultDatabase,
database,
dialect,
} = sqlTableContext;
let tableName = table.name;

// Set the fully qualified table name based on schema and database
Expand All @@ -39,11 +101,18 @@ export function sqlCode({
}
}

const formatter = getFormatter(dialect);
const formattedTableName = formatter.formatTableName(tableName);
const selectClause = formatter.formatSelectClause(
columnName,
formattedTableName,
);

if (engine === DUCKDB_ENGINE) {
return `_df = mo.sql(f"SELECT ${columnName} FROM ${tableName} LIMIT 100")`;
return `_df = mo.sql(f"""${selectClause}""")`;
}

return `_df = mo.sql(f"SELECT ${columnName} FROM ${tableName} LIMIT 100", engine=${engine})`;
return `_df = mo.sql(f"""${selectClause}""", engine=${engine})`;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe new lines around the query in case it ends with a "

}

return `_df = mo.sql(f'SELECT "${columnName}" FROM ${table.name} LIMIT 100')`;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class SQLCompletionStore {
if (!connection) {
return ModifiedStandardSQL;
}
return guessDialect(connection) ?? ModifiedStandardSQL;
return guessDialect(connection);
}

getCompletionSource(connectionName: ConnectionName): SQLConfig | null {
Expand All @@ -152,7 +152,7 @@ class SQLCompletionStore {
const schema = this.cache.getOrCreate(connection);

return {
dialect: guessDialect(connection) ?? ModifiedStandardSQL,
dialect: guessDialect(connection),
schema: schema.shouldAddLocalTables
? { ...schema.schema, ...getTablesMap() }
: schema.schema,
Expand Down
6 changes: 6 additions & 0 deletions frontend/src/core/codemirror/language/languages/sql/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import {
} from "./completion-sources";
import { SCHEMA_CACHE } from "./completion-store";
import { getSQLMode, type SQLMode } from "./sql-mode";
import { isKnownDialect } from "./utils";

const DEFAULT_DIALECT = DuckDBDialect;
const DEFAULT_PARSER_DIALECT = "DuckDB";
Expand Down Expand Up @@ -353,6 +354,11 @@ function connectionNameToParserDialect(
): ParserDialects | null {
const dialect =
SCHEMA_CACHE.getInternalDialect(connectionName)?.toLowerCase();

if (!dialect || !isKnownDialect(dialect)) {
return null;
}

switch (dialect) {
case "postgresql":
case "postgres":
Expand Down
42 changes: 39 additions & 3 deletions frontend/src/core/codemirror/language/languages/sql/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,49 @@ import {
} from "@marimo-team/codemirror-sql/dialects";
import type { DataSourceConnection } from "@/core/kernel/messages";

const KNOWN_DIALECTS_ARRAY = [
"postgresql",
"postgres",
"db2",
"mysql",
"sqlite",
"mssql",
"sqlserver",
"duckdb",
"mariadb",
"cassandra",
"noql",
"athena",
"bigquery",
"hive",
"redshift",
"snowflake",
"flink",
"mongodb",
"oracle",
"oracledb",
"timescaledb",
] as const;
const KNOWN_DIALECTS: ReadonlySet<string> = new Set(KNOWN_DIALECTS_ARRAY);
type KnownDialect = (typeof KNOWN_DIALECTS_ARRAY)[number];

export function isKnownDialect(dialect: string): dialect is KnownDialect {
return KNOWN_DIALECTS.has(dialect);
}

/**
* Guess the CodeMirror SQL dialect from the backend connection dialect.
* If unknown, return the standard SQL dialect.
*/
export function guessDialect(
connection: Pick<DataSourceConnection, "dialect">,
): SQLDialect | undefined {
switch (connection.dialect) {
): SQLDialect {
const dialect = connection.dialect;
if (!isKnownDialect(dialect)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it better to just fall though the switch-case so this does not get out of sync?

return ModifiedStandardSQL;
}

switch (dialect) {
case "postgresql":
case "postgres":
return PostgreSQL;
Expand All @@ -47,7 +83,7 @@ export function guessDialect(
case "bigquery":
return BigQueryDialect;
default:
return undefined;
return ModifiedStandardSQL;
}
}

Expand Down
6 changes: 6 additions & 0 deletions frontend/src/core/datasets/__tests__/data-source.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ describe("add table list", () => {
engine: "conn1" as ConnectionName,
database: "db1",
schema: "public",
dialect: "sqlite",
});

const conn1 = newState.connectionsMap.get("conn1" as ConnectionName);
Expand All @@ -283,6 +284,7 @@ describe("add table list", () => {
engine: "conn1" as ConnectionName,
database: "db1",
schema: "public",
dialect: "sqlite",
};

const tableList: DataTable[] = [
Expand Down Expand Up @@ -344,6 +346,7 @@ describe("add table list", () => {
engine: "conn1" as ConnectionName,
database: "db1",
schema: "non_existent",
dialect: "sqlite",
});

const conn1 = newState.connectionsMap.get("conn1" as ConnectionName);
Expand Down Expand Up @@ -407,6 +410,7 @@ describe("add table", () => {
engine: "conn1" as ConnectionName,
database: "db1",
schema: "public",
dialect: "sqlite",
});

const conn1 = newState.connectionsMap.get("conn1" as ConnectionName);
Expand All @@ -420,6 +424,7 @@ describe("add table", () => {
engine: "conn1" as ConnectionName,
database: "db1",
schema: "public",
dialect: "sqlite",
};

const table: DataTable = {
Expand Down Expand Up @@ -476,6 +481,7 @@ describe("add table", () => {
engine: "conn1" as ConnectionName,
database: "db1",
schema: "non_existent",
dialect: "sqlite",
});

const conn1 = newState.connectionsMap.get("conn1" as ConnectionName);
Expand Down
1 change: 1 addition & 0 deletions frontend/src/core/datasets/data-source-connections.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export interface SQLTableContext {
engine: string;
database: string;
schema: string;
dialect: string;
defaultSchema?: string | null;
defaultDatabase?: string | null;
}
Expand Down
Loading