diff --git a/marimo/_ai/_tools/tools/cells.py b/marimo/_ai/_tools/tools/cells.py index c79bc230ff5..25c14f7e4d6 100644 --- a/marimo/_ai/_tools/tools/cells.py +++ b/marimo/_ai/_tools/tools/cells.py @@ -3,12 +3,13 @@ from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from marimo._ai._tools.base import ToolBase from marimo._ai._tools.types import SuccessResult from marimo._ai._tools.utils.exceptions import ToolExecutionError from marimo._ast.models import CellData +from marimo._messaging.ops import VariableValue from marimo._types.ids import CellId_t, SessionId if TYPE_CHECKING: @@ -67,16 +68,7 @@ class CellRuntimeMetadata: execution_time: Optional[float] = None -@dataclass -class CellVariableValue: - name: str - # Cell variables can be arbitrary Python values (int, str, list, dict, ...), - # so we keep this as Any to reflect actual runtime. - value: Optional[Any] = None - data_type: Optional[str] = None - - -CellVariables = dict[str, CellVariableValue] +CellVariables = dict[str, VariableValue] @dataclass @@ -381,10 +373,6 @@ def _get_cell_variables( for var_name in cell_defs: if var_name in all_variables: var_value = all_variables[var_name] - cell_variables[var_name] = CellVariableValue( - name=var_name, - value=var_value.value, - data_type=var_value.datatype, - ) + cell_variables[var_name] = var_value return cell_variables diff --git a/marimo/_ai/_tools/tools/tables_and_variables.py b/marimo/_ai/_tools/tools/tables_and_variables.py index fa63ad4bdd7..c763afcf95e 100644 --- a/marimo/_ai/_tools/tools/tables_and_variables.py +++ b/marimo/_ai/_tools/tools/tables_and_variables.py @@ -3,9 +3,9 @@ from typing import Optional from marimo._ai._tools.base import ToolBase -from marimo._ai._tools.tools.cells import CellVariableValue from marimo._ai._tools.types import SuccessResult from marimo._data.models import DataTableColumn +from marimo._messaging.ops import 
VariableValue from marimo._server.sessions import Session from marimo._types.ids import SessionId @@ -42,7 +42,7 @@ class DataTableMetadata: @dataclass class TablesAndVariablesOutput(SuccessResult): tables: dict[str, DataTableMetadata] = field(default_factory=dict) - variables: dict[str, CellVariableValue] = field(default_factory=dict) + variables: dict[str, VariableValue] = field(default_factory=dict) class GetTablesAndVariables( @@ -98,13 +98,13 @@ def _get_tables_and_variables( engine=table.engine, ) - notebook_variables: dict[str, CellVariableValue] = {} + notebook_variables: dict[str, VariableValue] = {} for variable_name in filtered_variables: value = variables[variable_name] - notebook_variables[variable_name] = CellVariableValue( + notebook_variables[variable_name] = VariableValue( name=variable_name, value=value.value, - data_type=value.datatype, + datatype=value.datatype, ) return TablesAndVariablesOutput( diff --git a/marimo/_data/models.py b/marimo/_data/models.py index c25575fe72a..6f5ab9b9467 100644 --- a/marimo/_data/models.py +++ b/marimo/_data/models.py @@ -4,9 +4,8 @@ from datetime import date, datetime, time, timedelta # noqa: TCH003 from typing import TYPE_CHECKING, Any, Literal, Optional, Union -import msgspec - from marimo._types.ids import VariableName +from marimo._utils.msgspec_basestruct import BaseStruct DataType = Literal[ "string", @@ -23,7 +22,7 @@ ExternalDataType = str -class DataTableColumn(msgspec.Struct): +class DataTableColumn(BaseStruct): """ Represents a column in a data table. @@ -54,7 +53,7 @@ def __post_init__(self) -> None: DataTableType = Literal["table", "view"] -class DataTable(msgspec.Struct): +class DataTable(BaseStruct): """ Represents a data table. 
@@ -85,12 +84,12 @@ class DataTable(msgspec.Struct): indexes: Optional[list[str]] = None -class Schema(msgspec.Struct): +class Schema(BaseStruct): name: str tables: list[DataTable] -class Database(msgspec.Struct): +class Database(BaseStruct): """ Represents a collection of schemas. @@ -119,7 +118,7 @@ class Database(msgspec.Struct): NonNestedLiteral = Any -class ColumnStats(msgspec.Struct): +class ColumnStats(BaseStruct): """ Represents stats for a column in a data table. """ @@ -141,7 +140,7 @@ class ColumnStats(msgspec.Struct): p95: Optional[NonNestedLiteral] = None -class BinValue(msgspec.Struct): +class BinValue(BaseStruct): """ Represents bin values for a column in a data table. This is used for plotting. @@ -156,7 +155,7 @@ class BinValue(msgspec.Struct): count: int -class ValueCount(msgspec.Struct): +class ValueCount(BaseStruct): """ Represents a value and its count in a column in a data table. Currently used for string columns. @@ -170,7 +169,7 @@ class ValueCount(msgspec.Struct): count: int -class DataSourceConnection(msgspec.Struct): +class DataSourceConnection(BaseStruct): """ Represents a data source connection. 
diff --git a/marimo/_messaging/ops.py b/marimo/_messaging/ops.py index 12f3c42af59..99126f55ff4 100644 --- a/marimo/_messaging/ops.py +++ b/marimo/_messaging/ops.py @@ -54,6 +54,7 @@ from marimo._runtime.layout.layout import LayoutConfig from marimo._secrets.models import SecretKeysWithProvider from marimo._types.ids import CellId_t, RequestId, WidgetModelId +from marimo._utils.msgspec_basestruct import BaseStruct from marimo._utils.platform import is_pyodide, is_windows LOGGER = loggers.marimo_logger() @@ -501,7 +502,7 @@ class VariableDeclaration(msgspec.Struct): used_by: list[CellId_t] -class VariableValue(msgspec.Struct): +class VariableValue(BaseStruct): name: str value: Optional[str] datatype: Optional[str] diff --git a/marimo/_utils/msgspec_basestruct.py b/marimo/_utils/msgspec_basestruct.py new file mode 100644 index 00000000000..0a4397acb42 --- /dev/null +++ b/marimo/_utils/msgspec_basestruct.py @@ -0,0 +1,55 @@ +# Copyright 2025 Marimo. All rights reserved. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import msgspec + +if TYPE_CHECKING: + from pydantic import GetCoreSchemaHandler + from pydantic_core import CoreSchema + + +class BaseStruct(msgspec.Struct): + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + # Lazy import pydantic_core + from pydantic_core import core_schema + + # Build per-field schemas; msgspec reports a field without a default as NODEFAULT (not UNSET) + tdf: dict[str, core_schema.TypedDictField] = {} + for f in msgspec.structs.fields(cls): + tdf[f.name] = core_schema.typed_dict_field( + schema=handler.generate_schema(f.type), + required=( + f.default is msgspec.NODEFAULT + and getattr(f, "default_factory", msgspec.NODEFAULT) + is msgspec.NODEFAULT + ), + ) + td = core_schema.typed_dict_schema(tdf, total=True) + + # Create a function to convert a validated dictionary to the msgspec.Struct. 
+ def to_struct(values: dict[str, Any]) -> Any: + return msgspec.convert(values, cls, from_attributes=True) + + # Create a chain schema to validate the dictionary and convert to the msgspec.Struct. + chain = core_schema.chain_schema( + [td, core_schema.no_info_plain_validator_function(to_struct)] + ) + + # Return the json or python schema. + return core_schema.json_or_python_schema( + json_schema=chain, + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(cls), # fast-path + chain, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + msgspec.to_builtins + ), + ) diff --git a/pyproject.toml b/pyproject.toml index 84f8cb86d3e..eaaf37f562e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,7 @@ lsp = [ mcp = [ "mcp>=1.0.0; python_version >= '3.10'", + "pydantic>2; python_version >= '3.10'", ] @@ -516,6 +517,7 @@ Ue = "Ue" # Used in one of the cell IDs Nd = "Nd" # Confused with And pn = "pn" # Panel caf = "caf" # cafe +ser = "ser" # Used in pydantic-core package as part of a function name [tool.typos.files] extend-exclude = [ diff --git a/tests/_ai/tools/tools/test_cells.py b/tests/_ai/tools/tools/test_cells.py index 66efb2e495f..4d1b6a3fa0d 100644 --- a/tests/_ai/tools/tools/test_cells.py +++ b/tests/_ai/tools/tools/test_cells.py @@ -10,12 +10,13 @@ CellErrors, CellRuntimeMetadata, CellVariables, - CellVariableValue, GetCellRuntimeData, GetLightweightCellMap, ) from marimo._messaging.cell_output import CellChannel from marimo._messaging.ops import VariableValue +from marimo._server.sessions import Session +from marimo._types.ids import CellId_t, SessionId @dataclass @@ -63,7 +64,7 @@ def __post_init__(self) -> None: @dataclass -class MockSession: +class MockSession(Session): session_view: MockSessionView @@ -77,7 +78,7 @@ def test_get_cell_errors_no_cell_op(): tool = GetCellRuntimeData(ToolContext()) session = MockSession(MockSessionView()) - result = tool._get_cell_errors(session, "missing") + result = 
tool._get_cell_errors(session, CellId_t("missing")) assert result == CellErrors(has_errors=False, error_details=None) @@ -90,7 +91,7 @@ def test_get_cell_errors_with_marimo_error(): cell_op = MockCellOp(output=output) session = MockSession(MockSessionView(cell_operations={"c1": cell_op})) - result = tool._get_cell_errors(session, "c1") + result = tool._get_cell_errors(session, CellId_t("c1")) assert result.has_errors is True assert result.error_details is not None assert result.error_details[0].type == "NameError" @@ -102,7 +103,7 @@ def test_get_cell_errors_with_stderr(): cell_op = MockCellOp(console=[console_output]) session = MockSession(MockSessionView(cell_operations={"c1": cell_op})) - result = tool._get_cell_errors(session, "c1") + result = tool._get_cell_errors(session, CellId_t("c1")) assert result.has_errors is True assert result.error_details is not None assert result.error_details[0].type == "STDERR" @@ -130,7 +131,7 @@ def test_get_cell_metadata_basic(): ) ) - result = tool._get_cell_metadata(session, "c1") + result = tool._get_cell_metadata(session, CellId_t("c1")) assert result == CellRuntimeMetadata( runtime_state="idle", execution_time=42.5 ) @@ -140,7 +141,7 @@ def test_get_cell_metadata_no_cell_op(): tool = GetCellRuntimeData(ToolContext()) session = MockSession(MockSessionView()) - result = tool._get_cell_metadata(session, "missing") + result = tool._get_cell_metadata(session, CellId_t("missing")) assert result == CellRuntimeMetadata( runtime_state=None, execution_time=None ) @@ -154,9 +155,9 @@ def test_get_cell_variables(): cell_data = Mock() cell_data.cell = cell - var_x = VariableValue("x", 42, "int") + var_x = VariableValue("x", "42", "int") var_y = VariableValue("y", "hi", "str") - var_z = VariableValue("z", [1], "list") + var_z = VariableValue("z", "[1]", "list") session = MockSession( MockSessionView(variable_values={"x": var_x, "y": var_y, "z": var_z}) @@ -164,11 +165,11 @@ def test_get_cell_variables(): result = 
tool._get_cell_variables(session, cell_data) expected: CellVariables = { - "x": CellVariableValue( - name="x", value=var_x.value, data_type=var_x.datatype + "x": VariableValue( + name="x", value=var_x.value, datatype=var_x.datatype ), - "y": CellVariableValue( - name="y", value=var_y.value, data_type=var_y.datatype + "y": VariableValue( + name="y", value=var_y.value, datatype=var_y.datatype ), } assert result == expected @@ -210,7 +211,9 @@ def test_get_cell_runtime_data_invalid_cell(): from marimo._ai._tools.tools.cells import GetCellRuntimeDataArgs from marimo._ai._tools.utils.exceptions import ToolExecutionError - args = GetCellRuntimeDataArgs(session_id="test", cell_id="invalid") + args = GetCellRuntimeDataArgs( + session_id=SessionId("test"), cell_id=CellId_t("invalid") + ) with pytest.raises(ToolExecutionError) as exc_info: tool.handle(args) diff --git a/tests/_ai/tools/tools/test_tables_variables.py b/tests/_ai/tools/tools/test_tables_variables.py index ab9c0a2eb91..bdd8a593bb6 100644 --- a/tests/_ai/tools/tools/test_tables_variables.py +++ b/tests/_ai/tools/tools/test_tables_variables.py @@ -5,7 +5,6 @@ import pytest from marimo._ai._tools.base import ToolContext -from marimo._ai._tools.tools.cells import CellVariableValue from marimo._ai._tools.tools.tables_and_variables import ( DataTableMetadata, GetTablesAndVariables, @@ -13,6 +12,7 @@ ) from marimo._data.models import DataTableColumn from marimo._messaging.ops import VariableValue +from marimo._server.sessions import Session @dataclass @@ -39,7 +39,7 @@ class MockSessionView: @dataclass -class MockSession: +class MockSession(Session): session_view: MockSessionView @@ -53,9 +53,9 @@ def tool() -> GetTablesAndVariables: def sample_columns() -> list[DataTableColumn]: """Sample column information for testing.""" return [ - DataTableColumn("id", "int", "INTEGER", [1, 2, 3]), - DataTableColumn("name", "str", "VARCHAR", ["Alice", "Bob"]), - DataTableColumn("email", "str", "VARCHAR", ["alice@example.com"]), + 
DataTableColumn("id", "integer", "INTEGER", [1, 2, 3]), + DataTableColumn("name", "string", "VARCHAR", ["Alice", "Bob"]), + DataTableColumn("email", "string", "VARCHAR", ["alice@example.com"]), ] @@ -78,8 +78,8 @@ def sample_tables(sample_columns: list[DataTableColumn]) -> list[MockDataset]: num_rows=50, num_columns=2, columns=[ - DataTableColumn("order_id", "int", "INTEGER", [1, 2]), - DataTableColumn("user_id", "int", "INTEGER", [1, 2]), + DataTableColumn("order_id", "integer", "INTEGER", [1, 2]), + DataTableColumn("user_id", "integer", "INTEGER", [1, 2]), ], ), ] @@ -89,10 +89,10 @@ def sample_tables(sample_columns: list[DataTableColumn]) -> list[MockDataset]: def sample_variables() -> dict[str, VariableValue]: """Sample variable data for testing.""" return { - "x": VariableValue("x", 42, "int"), - "y": VariableValue("y", "hello", "str"), + "x": VariableValue("x", "42", "integer"), + "y": VariableValue("y", "hello", "string"), "df": VariableValue("df", None, "DataFrame"), - "my_list": VariableValue("my_list", [1, 2, 3], "list"), + "my_list": VariableValue("my_list", "[1, 2, 3]", "list"), } @@ -140,8 +140,8 @@ def test_get_tables_and_variables_empty_list( x_var = result.variables["x"] assert x_var.name == "x" - assert x_var.value == 42 - assert x_var.data_type == "int" + assert x_var.value == "42" + assert x_var.datatype == "integer" def test_get_tables_and_variables_specific_variables( @@ -214,30 +214,11 @@ def test_data_table_metadata_structure( id_column = users_table.columns[0] assert isinstance(id_column, DataTableColumn) assert id_column.name == "id" - assert id_column.type == "int" + assert id_column.type == "integer" assert id_column.external_type == "INTEGER" assert id_column.sample_values == [1, 2, 3] -def test_cell_variable_value_structure( - tool: GetTablesAndVariables, sample_session: MockSession -): - """Test that CellVariableValue is properly structured.""" - result = tool._get_tables_and_variables(sample_session, ["x", "my_list"]) - - x_var = 
result.variables["x"] - assert isinstance(x_var, CellVariableValue) - assert x_var.name == "x" - assert x_var.value == 42 - assert x_var.data_type == "int" - - list_var = result.variables["my_list"] - assert isinstance(list_var, CellVariableValue) - assert list_var.name == "my_list" - assert list_var.value == [1, 2, 3] - assert list_var.data_type == "list" - - def test_empty_session(tool: GetTablesAndVariables): """Test _get_tables_and_variables with empty session (no tables or variables).""" empty_session = MockSession( @@ -258,8 +239,8 @@ def test_table_with_no_primary_keys_or_indexes(tool: GetTablesAndVariables): num_rows=10, num_columns=2, columns=[ - DataTableColumn("col1", "str", "TEXT", ["a", "b"]), - DataTableColumn("col2", "int", "INTEGER", [1, 2]), + DataTableColumn("col1", "string", "TEXT", ["a", "b"]), + DataTableColumn("col2", "integer", "INTEGER", [1, 2]), ], primary_keys=None, indexes=None, @@ -297,7 +278,7 @@ def test_variable_with_none_value(tool: GetTablesAndVariables): none_var = result.variables["none_var"] assert none_var.name == "none_var" assert none_var.value is None - assert none_var.data_type == "NoneType" + assert none_var.datatype == "NoneType" def test_filtering_logic_separate_tables_and_variables( diff --git a/tests/_utils/test_msgspec_basestruct.py b/tests/_utils/test_msgspec_basestruct.py new file mode 100644 index 00000000000..5c51c463222 --- /dev/null +++ b/tests/_utils/test_msgspec_basestruct.py @@ -0,0 +1,86 @@ +import typing as t + +import msgspec + +from marimo._ai._tools.tools.cells import ( + GetCellRuntimeDataArgs, + GetCellRuntimeDataOutput, + GetLightweightCellMapArgs, + GetLightweightCellMapOutput, +) +from marimo._ai._tools.tools.datasource import ( + GetDatabaseTablesArgs, + GetDatabaseTablesOutput, +) +from marimo._ai._tools.tools.errors import ( + GetNotebookErrorsArgs, + GetNotebookErrorsOutput, +) +from marimo._ai._tools.tools.notebooks import ( + GetActiveNotebooksOutput, +) +from 
marimo._ai._tools.tools.tables_and_variables import ( + TablesAndVariablesArgs, + TablesAndVariablesOutput, +) + +TOOL_IO_CLASSES = [ + GetCellRuntimeDataArgs, + GetCellRuntimeDataOutput, + GetLightweightCellMapArgs, + GetLightweightCellMapOutput, + TablesAndVariablesArgs, + TablesAndVariablesOutput, + GetDatabaseTablesArgs, + GetDatabaseTablesOutput, + GetNotebookErrorsArgs, + GetNotebookErrorsOutput, + GetActiveNotebooksOutput, +] + + +def _iter_types(ann: t.Any): + stack = [ann] + seen: set[int] = set() + while stack: + tp = stack.pop() + obj_id = id(tp) + if obj_id in seen: + continue + seen.add(obj_id) + + if isinstance(tp, type): + yield tp + # Recurse into dataclass-like classes by following their annotations + anns = getattr(tp, "__annotations__", None) + if anns: + stack.extend(anns.values()) + continue + + origin = t.get_origin(tp) + if origin is not None: + stack.append(origin) + stack.extend(t.get_args(tp)) + + +def test_tool_msgspec_structs_expose_pydantic_hook() -> None: + offenders: list[str] = [] + for cls in TOOL_IO_CLASSES: + for ann in (getattr(cls, "__annotations__", {}) or {}).values(): + for resolved_type in _iter_types(ann): + if isinstance(resolved_type, type) and issubclass( + resolved_type, msgspec.Struct + ): + if not callable( + getattr( + resolved_type, "__get_pydantic_core_schema__", None + ) + ): + offenders.append( + f"{resolved_type.__module__}.{resolved_type.__name__} (referenced by {cls.__module__}.{cls.__name__})" + ) + + assert not offenders, ( + "These msgspec.Structs referenced by tools must use BaseStruct as their base class: " + + ", ".join(offenders) + )