Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/_data/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from marimo._data.models import DataTableColumn
from tests.utils import assert_serialize_roundtrip


def test_data_table_column_post_init() -> None:
Expand All @@ -9,3 +10,5 @@ def test_data_table_column_post_init() -> None:
sample_values=[],
)
assert column.name == "123"

assert_serialize_roundtrip(column)
24 changes: 20 additions & 4 deletions tests/_data/test_preview_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from marimo._runtime.requests import PreviewDatasetColumnRequest
from marimo._utils.platform import is_windows
from tests.mocks import snapshotter
from tests.utils import assert_serialize_roundtrip

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -84,6 +85,7 @@ def test_get_column_preview_for_dataframe() -> None:
assert result.chart_spec is not None
assert result.stats is not None
assert result.error is None
assert_serialize_roundtrip(result)


@pytest.mark.skipif(
Expand Down Expand Up @@ -148,6 +150,8 @@ def test_get_column_preview(column_name: str, snapshot_prefix: str) -> None:
# Verify vegafusion was checked
mock_dm.vegafusion.has.assert_called_once()

assert_serialize_roundtrip(result)

result_with_vegafusion = get_column_preview_dataset(
table=get_table_manager(df),
table_name="table",
Expand All @@ -168,6 +172,7 @@ def test_get_column_preview(column_name: str, snapshot_prefix: str) -> None:
)
assert result_with_vegafusion.chart_code == result.chart_code
assert result_with_vegafusion.chart_spec != result.chart_spec
assert_serialize_roundtrip(result_with_vegafusion)


@pytest.mark.skipif(
Expand All @@ -194,6 +199,7 @@ def test_get_column_preview_for_duckdb() -> None:
assert result is not None
assert result.stats is not None
assert result.error is None
assert_serialize_roundtrip(result)

# Check if summary contains expected statistics for the alternating pattern
assert result.stats.total == 100
Expand All @@ -211,6 +217,8 @@ def test_get_column_preview_for_duckdb() -> None:
assert result_id.error is None
assert result_id.chart_spec is not None

assert_serialize_roundtrip(result_id)

# Not implemented yet
assert result.chart_code is None

Expand Down Expand Up @@ -243,6 +251,7 @@ def test_get_column_preview_for_duckdb_categorical() -> None:
assert result_categorical is not None
assert result_categorical.stats is not None
assert result_categorical.error is None
assert_serialize_roundtrip(result_categorical)

# Check if summary contains expected statistics for the categorical pattern
assert result_categorical.stats.total == 100
Expand Down Expand Up @@ -292,6 +301,7 @@ def test_get_column_preview_for_duckdb_date() -> None:
assert result_date is not None
assert result_date.stats is not None
assert result_date.error is None
assert_serialize_roundtrip(result_date)

# Check if summary contains expected statistics for the date pattern
assert result_date.stats.total == 100
Expand Down Expand Up @@ -343,6 +353,7 @@ def test_get_column_preview_for_duckdb_datetime() -> None:
assert result_datetime is not None
assert result_datetime.stats is not None
assert result_datetime.error is None
assert_serialize_roundtrip(result_datetime)

# Check if summary contains expected statistics for the datetime pattern
assert result_datetime.stats.total == 100
Expand Down Expand Up @@ -393,6 +404,7 @@ def test_get_column_preview_for_duckdb_time() -> None:
assert result_time is not None
assert result_time.stats is not None
assert result_time.error is None
assert_serialize_roundtrip(result_time)

# Check if summary contains expected statistics for the time pattern
assert result_time.stats.total == 100
Expand Down Expand Up @@ -429,6 +441,7 @@ def test_get_column_preview_for_duckdb_bool() -> None:
assert result_bool is not None
assert result_bool.stats is not None
assert result_bool.error is None
assert_serialize_roundtrip(result_bool)

# Check if summary contains expected statistics for the boolean pattern
assert result_bool.stats.total == 100
Expand Down Expand Up @@ -477,6 +490,7 @@ def test_get_column_preview_for_duckdb_over_limit() -> None:
)
assert result.missing_packages == ["vegafusion", "vl_convert_python"]
assert result.chart_spec is None
assert_serialize_roundtrip(result)

# Not implemented yet
assert result.chart_code is None
Expand All @@ -499,7 +513,6 @@ def test_sanitize_dtypes() -> None:
# Sanitize the dtypes
result = _sanitize_data(nw_df, "cat_col")
assert result.collect_schema()["cat_col"] == nw.String

result = _sanitize_data(nw_df, "int128_col")
assert result.collect_schema()["int128_col"] == nw.Int64

Expand Down Expand Up @@ -555,6 +568,9 @@ def test_preview_column_duration_dtype() -> None:
table_name="table",
column_name=column_name,
)
assert result is not None
assert result.chart_code is not None
assert result.chart_spec is not None

assert result is not None
assert result.chart_code is not None
assert result.chart_spec is not None

assert_serialize_roundtrip(result)
11 changes: 8 additions & 3 deletions tests/_messaging/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from typing import Any, Optional

from marimo._messaging.mimetypes import ConsoleMimeType
from marimo._messaging.ops import CellOp, MessageOperation
from marimo._messaging.ops import (
CellOp,
MessageOperation,
deserialize_kernel_message,
)
from marimo._messaging.types import KernelMessage, Stderr, Stream
from marimo._utils.parse_dataclass import parse_raw


class MockStream(Stream):
Expand All @@ -17,6 +20,8 @@ def __init__(self, stream: Optional[Stream] = None) -> None:

def write(self, data: KernelMessage) -> None:
self.messages.append(data)
# Attempt to deserialize the message to ensure it is valid
deserialize_kernel_message(data)

@property
def operations(self) -> list[dict[str, Any]]:
Expand All @@ -27,7 +32,7 @@ def operations(self) -> list[dict[str, Any]]:
@property
def parsed_operations(self) -> list[MessageOperation]:
return [
parse_raw(op_data, MessageOperation) for op_data in self.messages
deserialize_kernel_message(op_data) for op_data in self.messages
]

@property
Expand Down
4 changes: 3 additions & 1 deletion tests/_runtime/test_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import marimo
from marimo._dependencies.dependencies import DependencyManager
from marimo._messaging.ops import CompletionResult
from marimo._messaging.ops import CompletionResult, deserialize_kernel_message
from marimo._messaging.types import KernelMessage, Stream
from marimo._runtime.complete import (
_build_docstring_cached,
Expand Down Expand Up @@ -367,6 +367,8 @@ def __init__(self):

def write(self, data: KernelMessage) -> None:
self.messages.append(data)
# Attempt to deserialize the message to ensure it is valid
deserialize_kernel_message(data)

@property
def operations(self) -> list[dict[str, Any]]:
Expand Down
22 changes: 17 additions & 5 deletions tests/_server/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

import os
import sys
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING

Expand All @@ -14,9 +14,11 @@
from marimo._config.utils import CONFIG_FILENAME
from marimo._server.api.deps import AppState
from marimo._server.main import create_starlette_app
from marimo._server.session.session_view import SessionView
from marimo._server.sessions import SessionManager
from marimo._server.utils import initialize_asyncio
from tests._server.mocks import get_mock_session_manager
from tests.utils import assert_serialize_roundtrip

if TYPE_CHECKING:
from collections.abc import Generator, Iterator
Expand All @@ -39,16 +41,15 @@ def client_with_lifespans() -> Generator[TestClient, None, None]:
@pytest.fixture
def user_config_manager() -> Iterator[UserConfigManager]:
tmp = TemporaryDirectory()
config_path = os.path.join(tmp.name, CONFIG_FILENAME)
with open(config_path, "w") as f:
f.write("")
config_path = Path(tmp.name) / CONFIG_FILENAME
config_path.write_text("")

class TestUserConfigManager(UserConfigManager):
def __init__(self) -> None:
super().__init__()

def get_config_path(self) -> str:
return config_path
return str(config_path)

yield TestUserConfigManager()

Expand Down Expand Up @@ -89,3 +90,14 @@ def get_session_config_manager(client: TestClient) -> UserConfigManager:

def get_user_config_manager(client: TestClient) -> UserConfigManager:
return client.app.state.config_manager # type: ignore


@pytest.fixture
def session_view() -> Generator[SessionView, None, None]:
sv = SessionView()

yield sv

# Test all operations can be serialized/deserialized
for operation in sv.operations:
assert_serialize_roundtrip(operation)
21 changes: 7 additions & 14 deletions tests/_server/export/test_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def _assert_contents():
)


def test_export_as_html_with_serialization():
def test_export_as_html_with_serialization(session_view: SessionView):
"""Test HTML export uses new serialization approach correctly."""
app = App()

Expand All @@ -512,7 +512,6 @@ def cell_2():
return (mo,)

file_manager = AppFileManager.from_app(InternalApp(app))
session_view = SessionView()

# Add some test data to session view
cell_ids = list(file_manager.app.cell_manager.cell_ids())
Expand Down Expand Up @@ -574,7 +573,7 @@ def cell_2():
assert 'data-marimo="true"' in html


def test_export_as_html_without_code():
def test_export_as_html_without_code(session_view: SessionView):
"""Test HTML export clears code when include_code=False."""
app = App()

Expand All @@ -585,7 +584,6 @@ def test_cell():
return secret_value

file_manager = AppFileManager.from_app(InternalApp(app))
session_view = SessionView()

cell_ids = list(file_manager.app.cell_manager.cell_ids())
session_view.cell_operations[cell_ids[0]] = CellOp(
Expand Down Expand Up @@ -636,7 +634,7 @@ def test_cell():
# The exact format depends on template implementation


def test_export_as_html_with_files():
def test_export_as_html_with_files(session_view: SessionView):
"""Test HTML export includes virtual files."""
app = App()

Expand All @@ -645,7 +643,6 @@ def test_cell():
return "test"

file_manager = AppFileManager.from_app(InternalApp(app))
session_view = SessionView()

cell_ids = list(file_manager.app.cell_manager.cell_ids())
session_view.cell_operations[cell_ids[0]] = CellOp(
Expand Down Expand Up @@ -684,7 +681,7 @@ def test_cell():
assert "data:" in html


def test_export_as_html_with_cell_configs():
def test_export_as_html_with_cell_configs(session_view: SessionView):
"""Test HTML export preserves cell configurations through serialization."""
app = App()

Expand All @@ -693,7 +690,6 @@ def configured_cell():
return "configured"

file_manager = AppFileManager.from_app(InternalApp(app))
session_view = SessionView()

cell_ids = list(file_manager.app.cell_manager.cell_ids())
session_view.cell_operations[cell_ids[0]] = CellOp(
Expand Down Expand Up @@ -733,7 +729,7 @@ def configured_cell():
assert "configured" in html


def test_export_as_html_preserves_output_order():
def test_export_as_html_preserves_output_order(session_view: SessionView):
"""Test HTML export preserves cell execution order in session snapshot."""
app = App()

Expand All @@ -750,7 +746,6 @@ def cell_third():
return "third"

file_manager = AppFileManager.from_app(InternalApp(app))
session_view = SessionView()

cell_ids = list(file_manager.app.cell_manager.cell_ids())

Expand Down Expand Up @@ -792,7 +787,7 @@ def cell_third():
assert "output_2" in html


def test_export_as_html_with_error_outputs():
def test_export_as_html_with_error_outputs(session_view: SessionView):
"""Test HTML export handles error outputs correctly."""
app = App()

Expand All @@ -801,7 +796,6 @@ def error_cell():
raise ValueError("Test error")

file_manager = AppFileManager.from_app(InternalApp(app))
session_view = SessionView()

cell_ids = list(file_manager.app.cell_manager.cell_ids())

Expand Down Expand Up @@ -846,7 +840,7 @@ def error_cell():
assert "Test error" in html or "ValueError" in html


def test_export_as_html_code_hash_consistency():
def test_export_as_html_code_hash_consistency(session_view: SessionView):
"""Test HTML export includes correct code hash regardless of include_code setting."""
app = App()

Expand All @@ -855,7 +849,6 @@ def test_cell():
return "test"

file_manager = AppFileManager.from_app(InternalApp(app))
session_view = SessionView()

cell_ids = list(file_manager.app.cell_manager.cell_ids())
session_view.cell_operations[cell_ids[0]] = CellOp(
Expand Down
12 changes: 3 additions & 9 deletions tests/_server/session/snapshots/mixed_error_session.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,18 @@
"outputs": [
{
"type": "error",
"ename": "ValueError",
"ename": "exception",
"evalue": "Invalid value",
"traceback": ["line 1", "line 2"]
"traceback": []
},
{
"type": "error",
"ename": "RuntimeError",
"evalue": "Runtime error occurred",
"traceback": []
},
{
"type": "error",
"ename": "TypeError",
"evalue": "Type mismatch",
"traceback": []
}
],
"console": []
}
]
}
}
Loading
Loading