19 changes: 16 additions & 3 deletions marimo/_ast/sql_visitor.py
@@ -98,7 +98,7 @@ def print_part(part: ast.expr) -> str:
     return result


-class TokenExtractor:
+class _TokenExtractor:
     def __init__(self, sql_statement: str, tokens: list[Any]) -> None:
         self.sql_statement = sql_statement
         self.tokens = tokens
@@ -178,7 +178,7 @@ def find_sql_defs(sql_statement: str) -> SQLDefs:
     import duckdb

     tokens = duckdb.tokenize(sql_statement)
-    token_extractor = TokenExtractor(
+    token_extractor = _TokenExtractor(
         sql_statement=sql_statement, tokens=tokens
     )
     created_tables: list[SQLRef] = []
@@ -529,7 +529,20 @@ def get_ref_from_table(table: exp.Table) -> Optional[SQLRef]:
         if expression is None:
             continue

-        if bool(expression.find(exp.Update, exp.Insert, exp.Delete)):
+        if bool(
+            expression.find(
+                exp.Update,
+                exp.Insert,
+                exp.Delete,
+                exp.Describe,
+                exp.Summarize,
+                exp.Pivot,
+                exp.Analyze,
+                exp.Drop,
+                exp.TruncateTable,
+                exp.Copy,
+            )
+        ):
             for table in expression.find_all(exp.Table):
                 if ref := get_ref_from_table(table):
                     refs.add(ref)
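For context on this hunk: sqlglot models each statement kind as an expression node, and find / find_all walk the parse tree, so widening the list of node types is enough to pick up tables referenced by DESCRIBE, SUMMARIZE, PIVOT, ANALYZE, DROP, TRUNCATE, and COPY. Below is a minimal sketch of that idea, not marimo's actual code path; it assumes sqlglot>=26.2.0 so that nodes like exp.Summarize and exp.Analyze exist.

    # Hedged sketch: parse a DuckDB statement and pull out its table references,
    # mirroring how the hunk above matches statement types before collecting refs.
    import sqlglot
    from sqlglot import exp

    expression = sqlglot.parse_one("SUMMARIZE my_schema.test_table", dialect="duckdb")

    # find() returns the first node matching any of the given types (or None),
    # so wrapping it in bool() answers "is this one of the statements we care about?"
    if expression.find(exp.Summarize, exp.Describe, exp.Analyze):
        for table in expression.find_all(exp.Table):
            # .catalog / .db / .name correspond to SQLRef's catalog / schema / table
            print(table.catalog, table.db, table.name)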
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -84,7 +84,7 @@ marimo_converter = "marimo._utils.docs:MarimoConverter"
 sql = [
     "duckdb>=1.0.0",
     "polars[pyarrow]>=1.9.0",
-    "sqlglot>=25.32"
+    "sqlglot>=26.2.0"
 ]

 # List of deps that are recommended for most users
@@ -94,7 +94,7 @@ recommended = [
     "duckdb>=1.0.0", # SQL cells
     "altair>=5.4.0", # Plotting in datasource viewer
     "polars[pyarrow]>=1.9.0", # SQL output back in Python
-    "sqlglot>=25.32", # SQL cells parsing
+    "sqlglot>=26.2.0", # SQL cells parsing
     "openai>=1.55.3", # AI features
     "ruff", # Formatting
     "nbformat>=5.7.0", # Export as IPYNB
@@ -108,7 +108,7 @@ dev = [
     "opentelemetry-sdk~=1.26.0",
     # For SQL
     "duckdb>=1.0.0",
-    "sqlglot>=25.32",
+    "sqlglot>=26.2.0",
     # For linting
     "ruff~=0.9.1",
     # For AI
@@ -180,7 +180,7 @@ dependencies = [
     # Types in 2.2.0 don't work great with mypy
     "narwhals>=1.34.1, <2.2.0",
     "matplotlib>=3.8.0",
-    "sqlglot>=25.32",
+    "sqlglot>=26.2.0",
     "sqlalchemy>=2.0.40",
     "google-genai>=1.0.0",
     "openai>=1.55.3",
@@ -262,7 +262,7 @@ extra-dependencies = [
     "altair>=5.4.0",
     "plotly>=5.14.0",
     "polars>=1.32.2",
-    "sqlglot>=25.32",
+    "sqlglot>=26.2.0",
     "sqlalchemy>=2.0.40",
     "pyiceberg>=0.9.0",
     # For testing clickhouse
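The bump from sqlglot>=25.32 to >=26.2.0 follows from the new expression types used above; older releases do not expose all of them and fail with an AttributeError in the visitor (see the review thread further down). As a hedged local check, not part of this PR, something like the following can confirm whether an installed sqlglot is new enough:

    # Hedged sketch: verify the installed sqlglot exposes the expression nodes
    # the visitor now relies on; on older releases some of these are missing.
    import sqlglot
    import sqlglot.expressions as exp

    needed = ["Describe", "Summarize", "Pivot", "Analyze", "Drop", "TruncateTable", "Copy"]
    missing = [name for name in needed if not hasattr(exp, name)]
    print(sqlglot.__version__, "missing:", missing or "none")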
56 changes: 56 additions & 0 deletions tests/_ast/test_sql_visitor.py
@@ -841,3 +841,59 @@ def test_read_file_and_urls(self) -> None:

         sql = "SELECT * FROM read_csv('/dev/stdin')"
         assert find_sql_refs(sql) == set()
+
+    def test_describe_table(self) -> None:
+        sql = "DESCRIBE test_duck;"
+        assert find_sql_refs(sql) == {SQLRef(table="test_duck")}
+
+    def test_summarize_table(self) -> None:
+        sql = "SUMMARIZE test_duck;"
+        assert find_sql_refs(sql) == {SQLRef(table="test_duck")}
+
+    def test_pivot_table(self) -> None:
+        sql = "PIVOT test_duck ON column_name USING sum(value);"
+        assert find_sql_refs(sql) == {SQLRef(table="test_duck")}
+
+    def test_unpivot_table(self) -> None:
+        sql = "UNPIVOT test_duck ON (col1, col2) INTO NAME column_name VALUE column_value;"
+        assert find_sql_refs(sql) == {SQLRef(table="test_duck")}
+
+    def test_describe_with_schema(self) -> None:
+        sql = "DESCRIBE my_schema.test_table;"
+        assert find_sql_refs(sql) == {
+            SQLRef(table="test_table", schema="my_schema")
+        }
+
+    def test_summarize_with_catalog_schema(self) -> None:
+        sql = "SUMMARIZE my_catalog.my_schema.test_table;"
+        assert find_sql_refs(sql) == {
+            SQLRef(
+                table="test_table", schema="my_schema", catalog="my_catalog"
+            )
+        }
+
+    def test_analyze_table(self) -> None:
+        sql = "ANALYZE test_table;"
+        assert find_sql_refs(sql) == {SQLRef(table="test_table")}
+
+    def test_drop_table(self) -> None:
+        sql = "DROP TABLE test_table;"
+        assert find_sql_refs(sql) == {SQLRef(table="test_table")}
+
+    def test_truncate_table(self) -> None:
+        sql = "TRUNCATE test_table;"
+        assert find_sql_refs(sql) == {SQLRef(table="test_table")}
+
+    def test_copy_table_to_file(self) -> None:
+        sql = "COPY test_table TO 'output.csv';"
+        assert find_sql_refs(sql) == {SQLRef(table="test_table")}
+
+    def test_copy_table_from_file(self) -> None:
+        sql = "COPY test_table FROM 'input.csv';"
+        assert find_sql_refs(sql) == {SQLRef(table="test_table")}
+
+    def test_analyze_with_schema(self) -> None:
+        sql = "ANALYZE my_schema.test_table;"
+        assert find_sql_refs(sql) == {
+            SQLRef(table="test_table", schema="my_schema")
+        }
2 changes: 1 addition & 1 deletion tests/_ast/test_visitor.py
@@ -1435,7 +1435,7 @@ def test_sql_table_f_string() -> None:
             "drop table and reference it in same statement",
             "drop table schema.cars; select * from cars",
             set(),
-            {"mo", "cars"},
+            {"mo", "cars", "schema.cars"},
Review thread on this line:

Contributor Author: @dmadisetti, @Light2Dark, does this change look right?

Contributor: I think it looks right.

Contributor (@Light2Dark, Sep 26, 2025): I see a lot more failing tests, will check.
Edit: oh, it only fails on minimal deps; maybe the old duckdb tokenizer.

Contributor Author:

    module 'sqlglot.expressions' has no attribute 'Analyze'

I'm going to see if I can bump the minimum.

         ),
         (
             "drop schema and reference table from it in same statement",