Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions marimo/_runtime/context/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
if TYPE_CHECKING:
from collections.abc import Iterator

import duckdb

from marimo._ast.app import (
AppKernelRunnerRegistry,
InternalApp,
Expand Down Expand Up @@ -67,6 +69,16 @@ class ExecutionContext:
local_cell_id: Optional[CellId_t] = None
# output object set imperatively
output: Optional[list[Html]] = None
duckdb_connection: duckdb.DuckDBPyConnection | None = None

@contextmanager
def with_connection(
self, connection: duckdb.DuckDBPyConnection
) -> Iterator[None]:
old_conn = self.duckdb_connection
self.duckdb_connection = connection
yield
self.duckdb_connection = old_conn


@dataclass
Expand Down
14 changes: 14 additions & 0 deletions marimo/_runtime/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ def interrupt_handler(signum: int, frame: Any) -> None:
# probability of that happening is low.
if context.execution_context is not None:
Interrupted().broadcast()
# DuckDB connections are sometimes left in an inconsistent
# state when interrupted by a SIGINT. Manually interrupting
# duckdb through its own API seems to be safer.
if context.execution_context.duckdb_connection is not None:
try:
context.execution_context.duckdb_connection.interrupt()
except Exception as e:
# Coarse try/except; let's not kill the kernel if something
# goes wrong.
LOGGER.warning(
"Failed to interrupt running duckdb connection. This "
"may be a bug in duckdb or marimo. %s",
e,
)
raise MarimoInterrupt

return interrupt_handler
Expand Down
59 changes: 59 additions & 0 deletions marimo/_smoke_tests/sql/duckdb_interrupt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import marimo

__generated_with = "0.16.5"
app = marimo.App(width="medium")


@app.cell
def _():
FILE = "https://data.source.coop/cholmes/eurocrops/unprojected/geoparquet/FR_2018_EC21.parquet"
return (FILE,)


@app.cell
def _():
import marimo as mo
return (mo,)


@app.cell
def _(mo):
_df = mo.sql(
f"""
INSTALL spatial;
LOAD spatial;
"""
)
return


@app.cell
def _():
10
return


@app.cell
def _(mo):
_df = mo.sql(
f"""
SELECT 1;
"""
)
return


@app.cell
def _(FILE, mo, null):
_df = mo.sql(
f"""
CREATE TABLE gdf AS
SELECT *
FROM '{FILE}'
"""
)
return


if __name__ == "__main__":
app.run()
47 changes: 42 additions & 5 deletions marimo/_sql/engines/duckdb.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast

from marimo import _loggers
from marimo._data.get_datasets import get_databases_from_duckdb
from marimo._data.models import Database, DataTable
from marimo._dependencies.dependencies import DependencyManager
from marimo._runtime.context.types import (
ContextNotInitializedError,
get_context,
)
from marimo._sql.engines.types import InferenceConfig, SQLConnection
from marimo._sql.utils import convert_to_output, wrapped_sql
from marimo._types.ids import VariableName

LOGGER = _loggers.marimo_logger()

if TYPE_CHECKING:
from collections.abc import Iterator

import duckdb

# Internal engine names
Expand All @@ -30,6 +37,24 @@ def __init__(
) -> None:
super().__init__(connection, engine_name)

@contextmanager
def _install_connection(
self, connection: duckdb.DuckDBPyConnection
) -> Iterator[None]:
try:
ctx = get_context()
except ContextNotInitializedError:
execution_context = None
else:
execution_context = ctx.execution_context
mgr = (
execution_context.with_connection
if execution_context is not None
else nullcontext
)
with mgr(connection):
yield

@property
def source(self) -> str:
return "duckdb"
Expand Down Expand Up @@ -87,8 +112,11 @@ def get_default_database(self) -> Optional[str]:
try:
import duckdb

connection = self._connection or duckdb
row = connection.sql("SELECT CURRENT_DATABASE()").fetchone()
connection = cast(
duckdb.DuckDBPyConnection, self._connection or duckdb
)
with self._install_connection(connection):
row = connection.sql("SELECT CURRENT_DATABASE()").fetchone()
if row is not None and row[0] is not None:
return str(row[0])
return None
Expand All @@ -100,8 +128,11 @@ def get_default_schema(self) -> Optional[str]:
try:
import duckdb

connection = self._connection or duckdb
row = connection.sql("SELECT CURRENT_SCHEMA()").fetchone()
connection = cast(
duckdb.DuckDBPyConnection, self._connection or duckdb
)
with self._install_connection(connection):
row = connection.sql("SELECT CURRENT_SCHEMA()").fetchone()
if row is not None and row[0] is not None:
return str(row[0])
return None
Expand All @@ -118,7 +149,13 @@ def get_databases(
) -> list[Database]:
"""Fetch all databases from the engine. At the moment, will fetch everything."""
_, _, _ = include_schemas, include_tables, include_table_details
return get_databases_from_duckdb(self._connection, self._engine_name)
import duckdb

connection = cast(
duckdb.DuckDBPyConnection, self._connection or duckdb
)
with self._install_connection(connection):
return get_databases_from_duckdb(connection, self._engine_name)

def get_tables_in_schema(
self, *, schema: str, database: str, include_table_details: bool
Expand Down
16 changes: 12 additions & 4 deletions marimo/_sql/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast

from marimo import _loggers
Expand Down Expand Up @@ -43,11 +44,18 @@ def wrapped_sql(
except ContextNotInitializedError:
relation = connection.sql(query=query)
else:
relation = eval(
"connection.sql(query=query)",
ctx.globals,
{"query": query, "connection": connection},
install_connection = (
ctx.execution_context.with_connection
if ctx.execution_context is not None
else nullcontext
)
with install_connection(connection):
relation = eval(
"connection.sql(query=query)",
ctx.globals,
{"query": query, "connection": connection},
)

return relation


Expand Down
Empty file added tests/_runtime/__init__.py
Empty file.
94 changes: 94 additions & 0 deletions tests/_runtime/test_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

import signal
from unittest.mock import MagicMock, patch

import pytest

from marimo._dependencies.dependencies import DependencyManager
from marimo._runtime.context.types import ExecutionContext
from marimo._runtime.handlers import construct_interrupt_handler
from marimo._runtime.runtime import MarimoInterrupt

HAS_DUCKDB = DependencyManager.duckdb.has()


@pytest.mark.skipif(not HAS_DUCKDB, reason="DuckDB not installed")
def test_duckdb_interrupt_handler_called_when_connection_present():
"""Test that duckdb.interrupt() is called when a connection is present."""
import duckdb

# Create a mock connection that we can spy on
mock_conn = MagicMock(spec=duckdb.DuckDBPyConnection)

# Create an execution context with a duckdb connection
exec_ctx = ExecutionContext(cell_id="cell_id", setting_element_value=False)

# Mock the context to return our execution context
with patch("marimo._runtime.handlers.get_context") as mock_get_context:
mock_context = MagicMock()
mock_context.execution_context = exec_ctx
mock_get_context.return_value = mock_context

# Verify interrupt() is called when connection is set
with exec_ctx.with_connection(mock_conn):
interrupt_handler = construct_interrupt_handler(mock_context)

# Trigger the interrupt handler
with pytest.raises(MarimoInterrupt):
interrupt_handler(signal.SIGINT, None)

# Verify duckdb connection's interrupt was called
mock_conn.interrupt.assert_called_once()


@pytest.mark.skipif(not HAS_DUCKDB, reason="DuckDB not installed")
def test_duckdb_interrupt_handler_no_error_when_connection_none():
"""Test that no error occurs when connection is None."""
# Create an execution context without a connection
exec_ctx = ExecutionContext(cell_id="cell_id", setting_element_value=False)
exec_ctx.duckdb_connection = None

# Mock the context to return our execution context
with patch("marimo._runtime.handlers.get_context") as mock_get_context:
mock_context = MagicMock()
mock_context.execution_context = exec_ctx
mock_get_context.return_value = mock_context

interrupt_handler = construct_interrupt_handler(mock_context)

# Should not raise error from duckdb interrupt (only MarimoInterrupt)
with pytest.raises(MarimoInterrupt):
interrupt_handler(signal.SIGINT, None)


@pytest.mark.skipif(not HAS_DUCKDB, reason="DuckDB not installed")
def test_duckdb_interrupt_handler_exception_handling():
"""Test that exceptions during interrupt() don't crash the kernel."""
import duckdb

# Create a mock connection that raises an exception
mock_conn = MagicMock(spec=duckdb.DuckDBPyConnection)
mock_conn.interrupt.side_effect = RuntimeError("Mock error")

# Create an execution context with a duckdb connection
exec_ctx = ExecutionContext(cell_id="cell_id", setting_element_value=False)

# Mock the context to return our execution context
with patch("marimo._runtime.handlers.get_context") as mock_get_context:
mock_context = MagicMock()
mock_context.execution_context = exec_ctx
mock_get_context.return_value = mock_context

# Make interrupt() raise an exception
with exec_ctx.with_connection(mock_conn):
interrupt_handler = construct_interrupt_handler(mock_context)

# Should raise MarimoInterrupt, not RuntimeError
# The RuntimeError should be caught and logged
with pytest.raises(MarimoInterrupt):
interrupt_handler(signal.SIGINT, None)

# Verify interrupt was attempted
mock_conn.interrupt.assert_called_once()
Loading