diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index bd13cc982e7..27e7c0598fe 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -98,7 +98,7 @@ def print_part(part: ast.expr) -> str: return result -class TokenExtractor: +class _TokenExtractor: def __init__(self, sql_statement: str, tokens: list[Any]) -> None: self.sql_statement = sql_statement self.tokens = tokens @@ -178,7 +178,7 @@ def find_sql_defs(sql_statement: str) -> SQLDefs: import duckdb tokens = duckdb.tokenize(sql_statement) - token_extractor = TokenExtractor( + token_extractor = _TokenExtractor( sql_statement=sql_statement, tokens=tokens ) created_tables: list[SQLRef] = [] @@ -529,7 +529,20 @@ def get_ref_from_table(table: exp.Table) -> Optional[SQLRef]: if expression is None: continue - if bool(expression.find(exp.Update, exp.Insert, exp.Delete)): + if bool( + expression.find( + exp.Update, + exp.Insert, + exp.Delete, + exp.Describe, + exp.Summarize, + exp.Pivot, + exp.Analyze, + exp.Drop, + exp.TruncateTable, + exp.Copy, + ) + ): for table in expression.find_all(exp.Table): if ref := get_ref_from_table(table): refs.add(ref) diff --git a/pyproject.toml b/pyproject.toml index 293a0cc28d8..e349a928f71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ marimo_converter = "marimo._utils.docs:MarimoConverter" sql = [ "duckdb>=1.0.0", "polars[pyarrow]>=1.9.0", - "sqlglot>=25.32" + "sqlglot>=26.2.0" ] # List of deps that are recommended for most users @@ -94,7 +94,7 @@ recommended = [ "duckdb>=1.0.0", # SQL cells "altair>=5.4.0", # Plotting in datasource viewer "polars[pyarrow]>=1.9.0", # SQL output back in Python - "sqlglot>=25.32", # SQL cells parsing + "sqlglot>=26.2.0", # SQL cells parsing "openai>=1.55.3", # AI features "ruff", # Formatting "nbformat>=5.7.0", # Export as IPYNB @@ -108,7 +108,7 @@ dev = [ "opentelemetry-sdk~=1.26.0", # For SQL "duckdb>=1.0.0", - "sqlglot>=25.32", + "sqlglot>=26.2.0", # For linting "ruff~=0.9.1", # For AI @@ -180,7 +180,7 @@ dependencies = [ # Types in 2.2.0 don't work great with mypy "narwhals>=1.34.1, <2.2.0", "matplotlib>=3.8.0", - "sqlglot>=25.32", + "sqlglot>=26.2.0", "sqlalchemy>=2.0.40", "google-genai>=1.0.0", "openai>=1.55.3", @@ -262,7 +262,7 @@ extra-dependencies = [ "altair>=5.4.0", "plotly>=5.14.0", "polars>=1.32.2", - "sqlglot>=25.32", + "sqlglot>=26.2.0", "sqlalchemy>=2.0.40", "pyiceberg>=0.9.0", # For testing clickhouse diff --git a/tests/_ast/test_sql_visitor.py b/tests/_ast/test_sql_visitor.py index cb2023fe8b3..81e04328c49 100644 --- a/tests/_ast/test_sql_visitor.py +++ b/tests/_ast/test_sql_visitor.py @@ -841,3 +841,59 @@ def test_read_file_and_urls(self) -> None: sql = "SELECT * FROM read_csv('/dev/stdin')" assert find_sql_refs(sql) == set() + + def test_describe_table(self) -> None: + sql = "DESCRIBE test_duck;" + assert find_sql_refs(sql) == {SQLRef(table="test_duck")} + + def test_summarize_table(self) -> None: + sql = "SUMMARIZE test_duck;" + assert find_sql_refs(sql) == {SQLRef(table="test_duck")} + + def test_pivot_table(self) -> None: + sql = "PIVOT test_duck ON column_name USING sum(value);" + assert find_sql_refs(sql) == {SQLRef(table="test_duck")} + + def test_unpivot_table(self) -> None: + sql = "UNPIVOT test_duck ON (col1, col2) INTO NAME column_name VALUE column_value;" + assert find_sql_refs(sql) == {SQLRef(table="test_duck")} + + def test_describe_with_schema(self) -> None: + sql = "DESCRIBE my_schema.test_table;" + assert find_sql_refs(sql) == { + SQLRef(table="test_table", schema="my_schema") + } + + def test_summarize_with_catalog_schema(self) -> None: + sql = "SUMMARIZE my_catalog.my_schema.test_table;" + assert find_sql_refs(sql) == { + SQLRef( + table="test_table", schema="my_schema", catalog="my_catalog" + ) + } + + def test_analyze_table(self) -> None: + sql = "ANALYZE test_table;" + assert find_sql_refs(sql) == {SQLRef(table="test_table")} + + def test_drop_table(self) -> None: + sql = "DROP TABLE test_table;" + assert find_sql_refs(sql) == {SQLRef(table="test_table")} + + def test_truncate_table(self) -> None: + sql = "TRUNCATE test_table;" + assert find_sql_refs(sql) == {SQLRef(table="test_table")} + + def test_copy_table_to_file(self) -> None: + sql = "COPY test_table TO 'output.csv';" + assert find_sql_refs(sql) == {SQLRef(table="test_table")} + + def test_copy_table_from_file(self) -> None: + sql = "COPY test_table FROM 'input.csv';" + assert find_sql_refs(sql) == {SQLRef(table="test_table")} + + def test_analyze_with_schema(self) -> None: + sql = "ANALYZE my_schema.test_table;" + assert find_sql_refs(sql) == { + SQLRef(table="test_table", schema="my_schema") + } diff --git a/tests/_ast/test_visitor.py b/tests/_ast/test_visitor.py index a2db4f70b17..60da8ddb55a 100644 --- a/tests/_ast/test_visitor.py +++ b/tests/_ast/test_visitor.py @@ -1435,7 +1435,7 @@ def test_sql_table_f_string() -> None: "drop table and reference it in same statement", "drop table schema.cars; select * from cars", set(), - {"mo", "cars"}, + {"mo", "cars", "schema.cars"}, ), ( "drop schema and reference table from it in same statement",