Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions marimo/_runtime/context/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
if TYPE_CHECKING:
from collections.abc import Iterator

import duckdb

from marimo._ast.app import (
AppKernelRunnerRegistry,
InternalApp,
Expand Down Expand Up @@ -67,6 +69,16 @@ class ExecutionContext:
local_cell_id: Optional[CellId_t] = None
# output object set imperatively
output: Optional[list[Html]] = None
duckdb_connection: duckdb.DuckDBPyConnection | None = None

@contextmanager
def with_connection(
self, connection: duckdb.DuckDBPyConnection
) -> Iterator[None]:
old_conn = self.duckdb_connection
self.duckdb_connection = connection
yield
self.duckdb_connection = old_conn


@dataclass
Expand Down
14 changes: 14 additions & 0 deletions marimo/_runtime/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ def interrupt_handler(signum: int, frame: Any) -> None:
# probability of that happening is low.
if context.execution_context is not None:
Interrupted().broadcast()
# DuckDB connections are sometimes left in an inconsistent
# state when interrupted by a SIGINT. Manually interrupting
# duckdb through its own API seems to be safer.
if context.execution_context.duckdb_connection is not None:
try:
context.execution_context.duckdb_connection.interrupt()
except Exception as e:
# Coarse try/except; let's not kill the kernel if something
# goes wrong.
LOGGER.warning(
"Failed to interrupt running duckdb connection. This "
"may be a bug in duckdb or marimo. %s",
e,
)
raise MarimoInterrupt

return interrupt_handler
Expand Down
16 changes: 12 additions & 4 deletions marimo/_sql/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast

from marimo import _loggers
Expand Down Expand Up @@ -43,11 +44,18 @@ def wrapped_sql(
except ContextNotInitializedError:
relation = connection.sql(query=query)
else:
relation = eval(
"connection.sql(query=query)",
ctx.globals,
{"query": query, "connection": connection},
install_connection = (
ctx.execution_context.with_connection
if ctx.execution_context is not None
else nullcontext
)
with install_connection(connection):
relation = eval(
"connection.sql(query=query)",
ctx.globals,
{"query": query, "connection": connection},
)

return relation


Expand Down
94 changes: 94 additions & 0 deletions tests/_runtime/test_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

import signal
from unittest.mock import MagicMock, patch

import pytest

from marimo._dependencies.dependencies import DependencyManager
from marimo._runtime.context.types import ExecutionContext
from marimo._runtime.handlers import construct_interrupt_handler
from marimo._runtime.runtime import MarimoInterrupt

HAS_DUCKDB = DependencyManager.duckdb.has()


@pytest.mark.skipif(not HAS_DUCKDB, reason="DuckDB not installed")
def test_duckdb_interrupt_handler_called_when_connection_present():
"""Test that duckdb.interrupt() is called when a connection is present."""
import duckdb

# Create a mock connection that we can spy on
mock_conn = MagicMock(spec=duckdb.DuckDBPyConnection)

# Create an execution context with a duckdb connection
exec_ctx = ExecutionContext(cell_id="cell_id", setting_element_value=False)

# Mock the context to return our execution context
with patch("marimo._runtime.handlers.get_context") as mock_get_context:
mock_context = MagicMock()
mock_context.execution_context = exec_ctx
mock_get_context.return_value = mock_context

# Verify interrupt() is called when connection is set
with exec_ctx.with_connection(mock_conn):
interrupt_handler = construct_interrupt_handler(mock_context)

# Trigger the interrupt handler
with pytest.raises(MarimoInterrupt):
interrupt_handler(signal.SIGINT, None)

# Verify duckdb connection's interrupt was called
mock_conn.interrupt.assert_called_once()


@pytest.mark.skipif(not HAS_DUCKDB, reason="DuckDB not installed")
def test_duckdb_interrupt_handler_no_error_when_connection_none():
"""Test that no error occurs when connection is None."""
# Create an execution context without a connection
exec_ctx = ExecutionContext(cell_id="cell_id", setting_element_value=False)
exec_ctx.duckdb_connection = None

# Mock the context to return our execution context
with patch("marimo._runtime.handlers.get_context") as mock_get_context:
mock_context = MagicMock()
mock_context.execution_context = exec_ctx
mock_get_context.return_value = mock_context

interrupt_handler = construct_interrupt_handler(mock_context)

# Should not raise error from duckdb interrupt (only MarimoInterrupt)
with pytest.raises(MarimoInterrupt):
interrupt_handler(signal.SIGINT, None)


@pytest.mark.skipif(not HAS_DUCKDB, reason="DuckDB not installed")
def test_duckdb_interrupt_handler_exception_handling():
"""Test that exceptions during interrupt() don't crash the kernel."""
import duckdb

# Create a mock connection that raises an exception
mock_conn = MagicMock(spec=duckdb.DuckDBPyConnection)
mock_conn.interrupt.side_effect = RuntimeError("Mock error")

# Create an execution context with a duckdb connection
exec_ctx = ExecutionContext(cell_id="cell_id", setting_element_value=False)

# Mock the context to return our execution context
with patch("marimo._runtime.handlers.get_context") as mock_get_context:
mock_context = MagicMock()
mock_context.execution_context = exec_ctx
mock_get_context.return_value = mock_context

# Make interrupt() raise an exception
with exec_ctx.with_connection(mock_conn):
interrupt_handler = construct_interrupt_handler(mock_context)

# Should raise MarimoInterrupt, not RuntimeError
# The RuntimeError should be caught and logged
with pytest.raises(MarimoInterrupt):
interrupt_handler(signal.SIGINT, None)

# Verify interrupt was attempted
mock_conn.interrupt.assert_called_once()
Loading