Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions marimo/_ast/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,11 +681,35 @@ def visit_Call(self, node: ast.Call) -> ast.Call:
self.generic_visit(node)
return node

# Try to process each statement individually
# For some SQL types (e.g., PIVOT with certain clauses),
# DuckDB's statement.query may fail, so we fall back to processing the full SQL
statement_queries: list[str] = []
use_full_sql = False
for statement in statements:
try:
statement_sql = statement.query
# Skip empty statements
if statement_sql.strip():
statement_queries.append(statement_sql)
except (IndexError, BaseException):
# Fallback to full SQL if we can't extract any individual statement
use_full_sql = True
break

# If we couldn't extract individual statements, process the full SQL once
if use_full_sql or not statement_queries:
statement_queries = [sql]

# Accumulate defined names across all statements in this SQL block
# so that later statements don't create refs to tables defined in earlier statements
defined_names: set[str] = set()

for statement_sql in statement_queries:
# Parse the refs and defs of each statement
# Add all tables/dbs created in the query to the defs
try:
sql_defs = find_sql_defs(sql)
sql_defs = find_sql_defs(statement_sql)
except duckdb.ProgrammingError:
sql_defs = SQLDefs()
except BaseException as e:
Expand All @@ -695,13 +719,11 @@ def visit_Call(self, node: ast.Call) -> ast.Call:
exception=e,
node=node,
rule_code="MF005",
sql_content=sql,
sql_content=statement_sql,
context="sql_defs_extraction",
)
sql_defs = SQLDefs()

defined_names = set()

for _table in sql_defs.tables:
self._define(
None,
Expand Down Expand Up @@ -730,7 +752,7 @@ def visit_Call(self, node: ast.Call) -> ast.Call:
sql_refs: set[SQLRef] = set()
try:
# Take results
sql_refs = find_sql_refs_cached(statement.query)
sql_refs = find_sql_refs_cached(statement_sql)
except (
duckdb.ProgrammingError,
duckdb.IOException,
Expand All @@ -744,7 +766,7 @@ def visit_Call(self, node: ast.Call) -> ast.Call:
exception=e,
node=first_arg,
rule_code="MF005",
sql_content=statement.query,
sql_content=statement_sql,
)

for ref in sql_refs:
Expand Down
60 changes: 60 additions & 0 deletions tests/_ast/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,3 +1536,63 @@ def test_sql_table_deleted_in_same_statement(
expected_defs.add("df")
assert v.defs == expected_defs, f"Failed for: {description}"
assert v.refs == expected_refs, f"Failed for: {description}"


@pytest.mark.skipif(not HAS_DEPS, reason="Requires duckdb")
@pytest.mark.parametrize(
(
"description",
"sql_statement",
"expected_refs",
),
[
(
"PIVOT with GROUP BY",
"pivot test_duck on function_type using count(*) group by categories",
{"mo", "test_duck"},
),
(
"PIVOT minimal",
"pivot test_duck on function_type",
{"mo", "test_duck"},
),
(
"PIVOT with USING",
"pivot test_duck on column_name using sum(value)",
{"mo", "test_duck"},
),
(
"UNPIVOT basic",
"unpivot test_duck on function_oid into name fld value val",
{"mo", "test_duck"},
),
(
"DESCRIBE table",
"describe test_duck",
{"mo", "test_duck"},
),
(
"SUMMARIZE table",
"summarize test_duck",
{"mo", "test_duck"},
),
(
"Multiple statements with PIVOT workaround",
"from test_duck limit 0; pivot test_duck on function_type",
{"mo", "test_duck"},
),
],
)
def test_sql_pivot_unpivot_commands(
description: str,
sql_statement: str,
expected_refs: set[str],
) -> None:
"""Test PIVOT, UNPIVOT, DESCRIBE, and SUMMARIZE commands (issue #6533)."""
code = f"df = mo.sql('{sql_statement}')"
v = visitor.ScopedVisitor()
mod = ast.parse(code)
v.visit(mod)

assert v.defs == {"df"}, f"Failed for: {description}"
assert v.refs == expected_refs, f"Failed for: {description}"
Loading