Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions marimo/_ai/_tools/tools/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional

from marimo._ai._tools.base import ToolBase
from marimo._ai._tools.types import SuccessResult
from marimo._ai._tools.utils.exceptions import ToolExecutionError
from marimo._ast.models import CellData
from marimo._messaging.ops import VariableValue
from marimo._types.ids import CellId_t, SessionId

if TYPE_CHECKING:
Expand Down Expand Up @@ -67,16 +68,7 @@ class CellRuntimeMetadata:
execution_time: Optional[float] = None


@dataclass
class CellVariableValue:
name: str
# Cell variables can be arbitrary Python values (int, str, list, dict, ...),
# so we keep this as Any to reflect actual runtime.
value: Optional[Any] = None
data_type: Optional[str] = None


CellVariables = dict[str, CellVariableValue]
CellVariables = dict[str, VariableValue]


@dataclass
Expand Down Expand Up @@ -381,10 +373,6 @@ def _get_cell_variables(
for var_name in cell_defs:
if var_name in all_variables:
var_value = all_variables[var_name]
cell_variables[var_name] = CellVariableValue(
name=var_name,
value=var_value.value,
data_type=var_value.datatype,
)
cell_variables[var_name] = var_value

return cell_variables
10 changes: 5 additions & 5 deletions marimo/_ai/_tools/tools/tables_and_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Optional

from marimo._ai._tools.base import ToolBase
from marimo._ai._tools.tools.cells import CellVariableValue
from marimo._ai._tools.types import SuccessResult
from marimo._data.models import DataTableColumn
from marimo._messaging.ops import VariableValue
from marimo._server.sessions import Session
from marimo._types.ids import SessionId

Expand Down Expand Up @@ -42,7 +42,7 @@ class DataTableMetadata:
@dataclass
class TablesAndVariablesOutput(SuccessResult):
tables: dict[str, DataTableMetadata] = field(default_factory=dict)
variables: dict[str, CellVariableValue] = field(default_factory=dict)
variables: dict[str, VariableValue] = field(default_factory=dict)


class GetTablesAndVariables(
Expand Down Expand Up @@ -98,13 +98,13 @@ def _get_tables_and_variables(
engine=table.engine,
)

notebook_variables: dict[str, CellVariableValue] = {}
notebook_variables: dict[str, VariableValue] = {}
for variable_name in filtered_variables:
value = variables[variable_name]
notebook_variables[variable_name] = CellVariableValue(
notebook_variables[variable_name] = VariableValue(
name=variable_name,
value=value.value,
data_type=value.datatype,
datatype=value.datatype,
)

return TablesAndVariablesOutput(
Expand Down
19 changes: 9 additions & 10 deletions marimo/_data/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from datetime import date, datetime, time, timedelta # noqa: TCH003
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import msgspec

from marimo._types.ids import VariableName
from marimo._utils.msgspec_basestruct import BaseStruct

DataType = Literal[
"string",
Expand All @@ -23,7 +22,7 @@
ExternalDataType = str


class DataTableColumn(msgspec.Struct):
class DataTableColumn(BaseStruct):
"""
Represents a column in a data table.

Expand Down Expand Up @@ -54,7 +53,7 @@ def __post_init__(self) -> None:
DataTableType = Literal["table", "view"]


class DataTable(msgspec.Struct):
class DataTable(BaseStruct):
"""
Represents a data table.

Expand Down Expand Up @@ -85,12 +84,12 @@ class DataTable(msgspec.Struct):
indexes: Optional[list[str]] = None


class Schema(msgspec.Struct):
class Schema(BaseStruct):
name: str
tables: list[DataTable]


class Database(msgspec.Struct):
class Database(BaseStruct):
"""
Represents a collection of schemas.

Expand Down Expand Up @@ -119,7 +118,7 @@ class Database(msgspec.Struct):
NonNestedLiteral = Any


class ColumnStats(msgspec.Struct):
class ColumnStats(BaseStruct):
"""
Represents stats for a column in a data table.
"""
Expand All @@ -141,7 +140,7 @@ class ColumnStats(msgspec.Struct):
p95: Optional[NonNestedLiteral] = None


class BinValue(msgspec.Struct):
class BinValue(BaseStruct):
"""
Represents bin values for a column in a data table. This is used for plotting.

Expand All @@ -156,7 +155,7 @@ class BinValue(msgspec.Struct):
count: int


class ValueCount(msgspec.Struct):
class ValueCount(BaseStruct):
"""
Represents a value and its count in a column in a data table.
Currently used for string columns.
Expand All @@ -170,7 +169,7 @@ class ValueCount(msgspec.Struct):
count: int


class DataSourceConnection(msgspec.Struct):
class DataSourceConnection(BaseStruct):
"""
Represents a data source connection.

Expand Down
3 changes: 2 additions & 1 deletion marimo/_messaging/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from marimo._runtime.layout.layout import LayoutConfig
from marimo._secrets.models import SecretKeysWithProvider
from marimo._types.ids import CellId_t, RequestId, WidgetModelId
from marimo._utils.msgspec_basestruct import BaseStruct
from marimo._utils.platform import is_pyodide, is_windows

LOGGER = loggers.marimo_logger()
Expand Down Expand Up @@ -501,7 +502,7 @@ class VariableDeclaration(msgspec.Struct):
used_by: list[CellId_t]


class VariableValue(msgspec.Struct):
class VariableValue(BaseStruct):
name: str
value: Optional[str]
datatype: Optional[str]
Expand Down
55 changes: 55 additions & 0 deletions marimo/_utils/msgspec_basestruct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import msgspec

if TYPE_CHECKING:
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema


class BaseStruct(msgspec.Struct):
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
# Lazy import pydantic_core
from pydantic_core import core_schema

# Build per-field schemas
tdf: dict[str, core_schema.TypedDictField] = {}
for f in msgspec.structs.fields(cls):
tdf[f.name] = core_schema.typed_dict_field(
schema=handler.generate_schema(f.type),
required=(
f.default is msgspec.UNSET
and getattr(f, "default_factory", msgspec.UNSET)
is msgspec.UNSET
),
)
td = core_schema.typed_dict_schema(tdf, total=True)

# Create a function to convert a msgspec.Struct to a dictionary.
def to_struct(values: dict[str, Any]) -> Any:
return msgspec.convert(values, cls, from_attributes=True)

# Create a chain schema to validate the dictionary and convert to the msgspec.Struct.
chain = core_schema.chain_schema(
[td, core_schema.no_info_plain_validator_function(to_struct)]
)

# Return the json or python schema.
return core_schema.json_or_python_schema(
json_schema=chain,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(cls), # fast-path
chain,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
msgspec.to_builtins
),
)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ lsp = [

mcp = [
"mcp>=1.0.0; python_version >= '3.10'",
"pydantic>2; python_version >= '3.10'",
]


Expand Down Expand Up @@ -516,6 +517,7 @@ Ue = "Ue" # Used in one of the cell IDs
Nd = "Nd" # Confused with And
pn = "pn" # Panel
caf = "caf" # cafe
ser = "ser" # Used in pydantic-core package as part of a function name

[tool.typos.files]
extend-exclude = [
Expand Down
31 changes: 17 additions & 14 deletions tests/_ai/tools/tools/test_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
CellErrors,
CellRuntimeMetadata,
CellVariables,
CellVariableValue,
GetCellRuntimeData,
GetLightweightCellMap,
)
from marimo._messaging.cell_output import CellChannel
from marimo._messaging.ops import VariableValue
from marimo._server.sessions import Session
from marimo._types.ids import CellId_t, SessionId


@dataclass
Expand Down Expand Up @@ -63,7 +64,7 @@ def __post_init__(self) -> None:


@dataclass
class MockSession:
class MockSession(Session):
session_view: MockSessionView


Expand All @@ -77,7 +78,7 @@ def test_get_cell_errors_no_cell_op():
tool = GetCellRuntimeData(ToolContext())
session = MockSession(MockSessionView())

result = tool._get_cell_errors(session, "missing")
result = tool._get_cell_errors(session, CellId_t("missing"))
assert result == CellErrors(has_errors=False, error_details=None)


Expand All @@ -90,7 +91,7 @@ def test_get_cell_errors_with_marimo_error():
cell_op = MockCellOp(output=output)
session = MockSession(MockSessionView(cell_operations={"c1": cell_op}))

result = tool._get_cell_errors(session, "c1")
result = tool._get_cell_errors(session, CellId_t("c1"))
assert result.has_errors is True
assert result.error_details is not None
assert result.error_details[0].type == "NameError"
Expand All @@ -102,7 +103,7 @@ def test_get_cell_errors_with_stderr():
cell_op = MockCellOp(console=[console_output])
session = MockSession(MockSessionView(cell_operations={"c1": cell_op}))

result = tool._get_cell_errors(session, "c1")
result = tool._get_cell_errors(session, CellId_t("c1"))
assert result.has_errors is True
assert result.error_details is not None
assert result.error_details[0].type == "STDERR"
Expand Down Expand Up @@ -130,7 +131,7 @@ def test_get_cell_metadata_basic():
)
)

result = tool._get_cell_metadata(session, "c1")
result = tool._get_cell_metadata(session, CellId_t("c1"))
assert result == CellRuntimeMetadata(
runtime_state="idle", execution_time=42.5
)
Expand All @@ -140,7 +141,7 @@ def test_get_cell_metadata_no_cell_op():
tool = GetCellRuntimeData(ToolContext())
session = MockSession(MockSessionView())

result = tool._get_cell_metadata(session, "missing")
result = tool._get_cell_metadata(session, CellId_t("missing"))
assert result == CellRuntimeMetadata(
runtime_state=None, execution_time=None
)
Expand All @@ -154,21 +155,21 @@ def test_get_cell_variables():
cell_data = Mock()
cell_data.cell = cell

var_x = VariableValue("x", 42, "int")
var_x = VariableValue("x", "42", "int")
var_y = VariableValue("y", "hi", "str")
var_z = VariableValue("z", [1], "list")
var_z = VariableValue("z", "[1]", "list")

session = MockSession(
MockSessionView(variable_values={"x": var_x, "y": var_y, "z": var_z})
)

result = tool._get_cell_variables(session, cell_data)
expected: CellVariables = {
"x": CellVariableValue(
name="x", value=var_x.value, data_type=var_x.datatype
"x": VariableValue(
name="x", value=var_x.value, datatype=var_x.datatype
),
"y": CellVariableValue(
name="y", value=var_y.value, data_type=var_y.datatype
"y": VariableValue(
name="y", value=var_y.value, datatype=var_y.datatype
),
}
assert result == expected
Expand Down Expand Up @@ -210,7 +211,9 @@ def test_get_cell_runtime_data_invalid_cell():
from marimo._ai._tools.tools.cells import GetCellRuntimeDataArgs
from marimo._ai._tools.utils.exceptions import ToolExecutionError

args = GetCellRuntimeDataArgs(session_id="test", cell_id="invalid")
args = GetCellRuntimeDataArgs(
session_id=SessionId("test"), cell_id=CellId_t("invalid")
)

with pytest.raises(ToolExecutionError) as exc_info:
tool.handle(args)
Expand Down
Loading
Loading