Skip to content

Commit a37b800

Browse files
authored
fix: add more expressions for SQL refs (#6538)
Fixes #6533 This adds more SQL expressions that capture refs. Bumps sqlglot min dep to a version from Jan 2025
1 parent 004c2d5 commit a37b800

File tree

4 files changed

+78
-9
lines changed

4 files changed

+78
-9
lines changed

marimo/_ast/sql_visitor.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def print_part(part: ast.expr) -> str:
9898
return result
9999

100100

101-
class TokenExtractor:
101+
class _TokenExtractor:
102102
def __init__(self, sql_statement: str, tokens: list[Any]) -> None:
103103
self.sql_statement = sql_statement
104104
self.tokens = tokens
@@ -178,7 +178,7 @@ def find_sql_defs(sql_statement: str) -> SQLDefs:
178178
import duckdb
179179

180180
tokens = duckdb.tokenize(sql_statement)
181-
token_extractor = TokenExtractor(
181+
token_extractor = _TokenExtractor(
182182
sql_statement=sql_statement, tokens=tokens
183183
)
184184
created_tables: list[SQLRef] = []
@@ -529,7 +529,20 @@ def get_ref_from_table(table: exp.Table) -> Optional[SQLRef]:
529529
if expression is None:
530530
continue
531531

532-
if bool(expression.find(exp.Update, exp.Insert, exp.Delete)):
532+
if bool(
533+
expression.find(
534+
exp.Update,
535+
exp.Insert,
536+
exp.Delete,
537+
exp.Describe,
538+
exp.Summarize,
539+
exp.Pivot,
540+
exp.Analyze,
541+
exp.Drop,
542+
exp.TruncateTable,
543+
exp.Copy,
544+
)
545+
):
533546
for table in expression.find_all(exp.Table):
534547
if ref := get_ref_from_table(table):
535548
refs.add(ref)

pyproject.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ marimo_converter = "marimo._utils.docs:MarimoConverter"
8484
sql = [
8585
"duckdb>=1.0.0",
8686
"polars[pyarrow]>=1.9.0",
87-
"sqlglot>=25.32"
87+
"sqlglot>=26.2.0"
8888
]
8989

9090
# List of deps that are recommended for most users
@@ -94,7 +94,7 @@ recommended = [
9494
"duckdb>=1.0.0", # SQL cells
9595
"altair>=5.4.0", # Plotting in datasource viewer
9696
"polars[pyarrow]>=1.9.0", # SQL output back in Python
97-
"sqlglot>=25.32", # SQL cells parsing
97+
"sqlglot>=26.2.0", # SQL cells parsing
9898
"openai>=1.55.3", # AI features
9999
"ruff", # Formatting
100100
"nbformat>=5.7.0", # Export as IPYNB
@@ -108,7 +108,7 @@ dev = [
108108
"opentelemetry-sdk~=1.26.0",
109109
# For SQL
110110
"duckdb>=1.0.0",
111-
"sqlglot>=25.32",
111+
"sqlglot>=26.2.0",
112112
# For linting
113113
"ruff~=0.9.1",
114114
# For AI
@@ -181,7 +181,7 @@ dependencies = [
181181
# Types in 2.2.0 don't work great with mypy
182182
"narwhals>=1.34.1, <2.2.0",
183183
"matplotlib>=3.8.0",
184-
"sqlglot>=25.32",
184+
"sqlglot>=26.2.0",
185185
"sqlalchemy>=2.0.40",
186186
"google-genai>=1.0.0",
187187
"openai>=1.55.3",
@@ -263,7 +263,7 @@ extra-dependencies = [
263263
"altair>=5.4.0",
264264
"plotly>=5.14.0",
265265
"polars>=1.32.2",
266-
"sqlglot>=25.32",
266+
"sqlglot>=26.2.0",
267267
"sqlalchemy>=2.0.40",
268268
"pyiceberg>=0.9.0",
269269
# For testing clickhouse

tests/_ast/test_sql_visitor.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,3 +841,59 @@ def test_read_file_and_urls(self) -> None:
841841

842842
sql = "SELECT * FROM read_csv('/dev/stdin')"
843843
assert find_sql_refs(sql) == set()
844+
845+
def test_describe_table(self) -> None:
846+
sql = "DESCRIBE test_duck;"
847+
assert find_sql_refs(sql) == {SQLRef(table="test_duck")}
848+
849+
def test_summarize_table(self) -> None:
850+
sql = "SUMMARIZE test_duck;"
851+
assert find_sql_refs(sql) == {SQLRef(table="test_duck")}
852+
853+
def test_pivot_table(self) -> None:
854+
sql = "PIVOT test_duck ON column_name USING sum(value);"
855+
assert find_sql_refs(sql) == {SQLRef(table="test_duck")}
856+
857+
def test_unpivot_table(self) -> None:
858+
sql = "UNPIVOT test_duck ON (col1, col2) INTO NAME column_name VALUE column_value;"
859+
assert find_sql_refs(sql) == {SQLRef(table="test_duck")}
860+
861+
def test_describe_with_schema(self) -> None:
862+
sql = "DESCRIBE my_schema.test_table;"
863+
assert find_sql_refs(sql) == {
864+
SQLRef(table="test_table", schema="my_schema")
865+
}
866+
867+
def test_summarize_with_catalog_schema(self) -> None:
868+
sql = "SUMMARIZE my_catalog.my_schema.test_table;"
869+
assert find_sql_refs(sql) == {
870+
SQLRef(
871+
table="test_table", schema="my_schema", catalog="my_catalog"
872+
)
873+
}
874+
875+
def test_analyze_table(self) -> None:
876+
sql = "ANALYZE test_table;"
877+
assert find_sql_refs(sql) == {SQLRef(table="test_table")}
878+
879+
def test_drop_table(self) -> None:
880+
sql = "DROP TABLE test_table;"
881+
assert find_sql_refs(sql) == {SQLRef(table="test_table")}
882+
883+
def test_truncate_table(self) -> None:
884+
sql = "TRUNCATE test_table;"
885+
assert find_sql_refs(sql) == {SQLRef(table="test_table")}
886+
887+
def test_copy_table_to_file(self) -> None:
888+
sql = "COPY test_table TO 'output.csv';"
889+
assert find_sql_refs(sql) == {SQLRef(table="test_table")}
890+
891+
def test_copy_table_from_file(self) -> None:
892+
sql = "COPY test_table FROM 'input.csv';"
893+
assert find_sql_refs(sql) == {SQLRef(table="test_table")}
894+
895+
def test_analyze_with_schema(self) -> None:
896+
sql = "ANALYZE my_schema.test_table;"
897+
assert find_sql_refs(sql) == {
898+
SQLRef(table="test_table", schema="my_schema")
899+
}

tests/_ast/test_visitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1435,7 +1435,7 @@ def test_sql_table_f_string() -> None:
14351435
"drop table and reference it in same statement",
14361436
"drop table schema.cars; select * from cars",
14371437
set(),
1438-
{"mo", "cars"},
1438+
{"mo", "cars", "schema.cars"},
14391439
),
14401440
(
14411441
"drop schema and reference table from it in same statement",

0 commit comments

Comments
 (0)