Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,25 @@ make py-test
Run a specific test

```bash
hatch run test:test tests/_ast/
hatch run +py=3.13 test:test tests/_ast/
```

Run all changed tests

```bash
hatch run +py=3.13 test:test --picked
```

Run tests with optional dependencies

```bash
hatch run test-optional:test tests/_ast/
hatch run +py=3.13 test-optional:test tests/_ast/
```

Run tests with a specific Python version
Run tests across all Python versions (omit `+py`)

```bash
hatch run +py=3.10 test:test tests/_ast/
hatch run test:test tests/_ast/
```

Run all tests across all Python versions
Expand Down
20 changes: 10 additions & 10 deletions marimo/_data/preview_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def _get_altair_chart(
try:
chart_spec = _get_chart_spec(
# Downgrade to v1 since altair doesn't support v2 yet
# This is valiadted with our tests, so if the tests pass with this
# This is validated with our tests, so if the tests pass with this
# removed, we can remove the downgrade.
column_data=downgrade_narwhals_df_to_v1(column_data),
column_type=column_type,
Expand Down Expand Up @@ -369,49 +369,49 @@ def _sanitize_data(
total_seconds = diff.total_seconds()
if total_seconds >= 604800:
# Use weeks if range is at least a week
column_data = column_data.with_columns(
column_data = frame.with_columns(
(col.dt.total_seconds() / 604800).alias(
column_name
)
)
elif total_seconds >= 86400:
# Use days if range is at least a day
column_data = column_data.with_columns(
column_data = frame.with_columns(
(col.dt.total_seconds() / 86400).alias(column_name)
)
elif total_seconds >= 3600:
# Use hours if range is at least an hour
column_data = column_data.with_columns(
column_data = frame.with_columns(
(col.dt.total_seconds() / 3600).alias(column_name)
)
elif total_seconds >= 60:
# Use minutes if range is at least a minute
column_data = column_data.with_columns(
column_data = frame.with_columns(
col.dt.total_minutes().alias(column_name)
)
elif total_seconds >= 1:
# Use seconds if range is at least a second
column_data = column_data.with_columns(
column_data = frame.with_columns(
col.dt.total_seconds().alias(column_name)
)
elif total_seconds >= 0.001:
# Use milliseconds if range is at least a millisecond
column_data = column_data.with_columns(
column_data = frame.with_columns(
col.dt.total_milliseconds().alias(column_name)
)
elif total_seconds >= 0.000001:
# Use microseconds if range is at least a microsecond
column_data = column_data.with_columns(
column_data = frame.with_columns(
col.dt.total_microseconds().alias(column_name)
)
elif total_seconds >= 0.000000001:
# Use nanoseconds if range is at least a nanosecond
column_data = column_data.with_columns(
column_data = frame.with_columns(
col.dt.total_nanoseconds().alias(column_name)
)
except Exception as e:
LOGGER.warning("Failed to infer duration precision: %s", e)
column_data = column_data.with_columns(
column_data = frame.with_columns(
col.dt.total_seconds().alias(column_name)
)
except Exception as e:
Expand Down
26 changes: 20 additions & 6 deletions marimo/_plugins/ui/_impl/altair_chart.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2024 Marimo. All rights reserved.

from __future__ import annotations

import datetime
Expand Down Expand Up @@ -29,6 +30,7 @@
assert_can_narwhalify,
can_narwhalify,
empty_df,
is_narwhals_lazyframe,
)

if sys.version_info < (3, 10):
Expand Down Expand Up @@ -100,7 +102,11 @@ def _using_vegafusion() -> bool:
def _filter_dataframe(
native_df: Union[IntoDataFrame, IntoLazyFrame], selection: ChartSelection
) -> Union[IntoDataFrame, IntoLazyFrame]:
df = nw.from_native(native_df)
# Use lazy evaluation for efficient chained filtering
base = nw.from_native(native_df)
is_lazy = is_narwhals_lazyframe(base)
df = base.lazy()

if not isinstance(selection, dict):
raise TypeError("Input 'selection' must be a dictionary")

Expand All @@ -121,10 +127,12 @@ def _filter_dataframe(
vgsid = fields.get("_vgsid_", [])
try:
indexes = [int(i) - 1 for i in vgsid]
df = cast(nw.DataFrame[Any], df)[indexes]
# Need to collect for index-based selection
non_lazy = df.collect()[indexes]
df = non_lazy.lazy()
except IndexError:
# Out of bounds index, return empty dataframe if it's the
df = cast(nw.DataFrame[Any], df)[[]]
df = df.head(0)
LOGGER.error(f"Invalid index in selection: {vgsid}")
except ValueError:
LOGGER.error(f"Invalid index in selection: {vgsid}")
Expand All @@ -140,10 +148,12 @@ def _filter_dataframe(
if field in ("vlPoint", "_vgsid_"):
continue

if field not in df.columns:
# Need to collect schema to check columns and dtypes
schema = df.collect_schema()
if field not in schema.names():
raise ValueError(f"Field '{field}' not found in DataFrame")

dtype = df[field].dtype
dtype = schema[field]
resolved_values = _resolve_values(values, dtype)
if is_point_selection:
df = df.filter(nw.col(field).is_in(resolved_values))
Expand All @@ -168,7 +178,11 @@ def _filter_dataframe(
f"Invalid selection: {field}={resolved_values}"
)

return nw.to_native(df)
if not is_lazy and is_narwhals_lazyframe(df):
# Undo the lazy
return df.collect().to_native() # type: ignore[no-any-return]

return df.to_native()


def _resolve_values(values: Any, dtype: Any) -> list[Any]:
Expand Down
18 changes: 16 additions & 2 deletions marimo/_plugins/ui/_impl/charts/altair_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
get_table_manager_or_none,
)
from marimo._utils.data_uri import build_data_url
from marimo._utils.narwhals_utils import can_narwhalify
from marimo._utils.narwhals_utils import can_narwhalify, is_narwhals_lazyframe

LOGGER = _loggers.marimo_logger()

Expand Down Expand Up @@ -154,7 +154,16 @@ def sanitize_nan_infs(data: Any) -> Any:
"""Sanitize NaN and Inf values in Dataframes for JSON serialization."""
if can_narwhalify(data):
narwhals_data = nw.from_native(data)
for col, dtype in narwhals_data.schema.items():
is_prev_lazy = isinstance(narwhals_data, nw.LazyFrame)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we use the is_narwhals_lazyframe helper here too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep good call. just updated in a followup PR


# Convert to lazy for optimization if not already lazy
if not is_prev_lazy:
narwhals_data = narwhals_data.lazy()

# Get schema without collecting
schema = narwhals_data.collect_schema()

for col, dtype in schema.items():
# Only numeric columns can have NaN or Inf values
if dtype.is_numeric():
narwhals_data = narwhals_data.with_columns(
Expand All @@ -163,6 +172,11 @@ def sanitize_nan_infs(data: Any) -> Any:
.otherwise(nw.col(col))
.name.keep()
)

# Collect if input was eager
if not is_prev_lazy and is_narwhals_lazyframe(narwhals_data):
narwhals_data = narwhals_data.collect()

return narwhals_data.to_native()
return data

Expand Down
5 changes: 4 additions & 1 deletion marimo/_utils/narwhals_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,10 @@ def downgrade_narwhals_df_to_v1(
"""
Downgrade a narwhals dataframe to the narwhals v1 API.
"""
return nw1.from_native(df.to_native()) # type: ignore[no-any-return]
if is_narwhals_lazyframe(df) or is_narwhals_dataframe(df):
return nw1.from_native(df.to_native()) # type: ignore[no-any-return]
# Pass through
return df


def is_narwhals_lazyframe(df: Any) -> TypeIs[nw.LazyFrame[Any]]:
Expand Down
42 changes: 23 additions & 19 deletions tests/_data/test_preview_column.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime, time, timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from unittest.mock import patch

import pytest
Expand All @@ -19,6 +19,7 @@
)
from marimo._runtime.requests import PreviewDatasetColumnRequest
from marimo._utils.platform import is_windows
from tests._data.mocks import create_dataframes
from tests.mocks import snapshotter
from tests.utils import assert_serialize_roundtrip

Expand Down Expand Up @@ -60,13 +61,15 @@ def cleanup() -> Generator[None, None, None]:
not HAS_DF_DEPS, reason="optional dependencies not installed"
)
@pytest.mark.skipif(is_windows(), reason="Windows encodes base64 differently")
def test_get_column_preview_for_dataframe() -> None:
import pandas as pd

@pytest.mark.parametrize(
"df",
create_dataframes(
{"A": [1, 2, 3], "B": ["a", "a", "a"]},
),
)
def test_get_column_preview_for_dataframe(df: Any) -> None:
register_transformers()

df = pd.DataFrame({"A": [1, 2, 3], "B": ["a", "a", "a"]})

for column_name in ["A", "B"]:
result = get_column_preview_for_dataframe(
df,
Expand Down Expand Up @@ -518,9 +521,9 @@ def test_sanitize_dtypes() -> None:


@pytest.mark.skipif(
not DependencyManager.narwhals.has(), reason="narwhals not installed"
not DependencyManager.narwhals.has() or not DependencyManager.polars.has(),
reason="narwhals and polars not installed",
)
@pytest.mark.xfail(reason="Sanitizing is failing") # TODO: Fix this
def test_sanitize_dtypes_enum() -> None:
import narwhals as nw
import polars as pl
Expand All @@ -534,17 +537,16 @@ def test_sanitize_dtypes_enum() -> None:
nw_df = nw.from_native(df)

result = _sanitize_data(nw_df, "enum_col")
assert result.schema["enum_col"] == nw.String
assert result.collect_schema()["enum_col"] == nw.String

lazy_df = nw_df.lazy()
result = _sanitize_data(lazy_df, "enum_col")
assert result.collect_schema()["enum_col"] == nw.String

@pytest.mark.skipif(
not DependencyManager.polars.has(), reason="polars not installed"
)
def test_preview_column_duration_dtype() -> None:
import polars as pl

# Test days conversion
df = pl.DataFrame(
@pytest.mark.parametrize(
"df",
create_dataframes(
{
"duration_weeks": [timedelta(weeks=1), timedelta(weeks=2)],
"duration_days": [timedelta(days=1), timedelta(days=2)],
Expand All @@ -559,9 +561,11 @@ def test_preview_column_duration_dtype() -> None:
timedelta(microseconds=1),
timedelta(microseconds=2),
],
}
)

},
exclude=["pyarrow", "duckdb", "ibis"],
),
)
def test_preview_column_duration_dtype(df) -> None:
for column_name in df.columns:
result = get_column_preview_dataset(
table=get_table_manager(df),
Expand Down
17 changes: 7 additions & 10 deletions tests/_plugins/ui/_impl/charts/test_altair_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
_to_marimo_json,
register_transformers,
)
from tests._data.mocks import create_dataframes
from tests._data.mocks import DFType, create_dataframes

HAS_DEPS = DependencyManager.pandas.has() and DependencyManager.altair.has()

Expand All @@ -30,7 +30,7 @@
@pytest.mark.parametrize(
"df",
create_dataframes(
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["duckdb"]
{"A": [1, 2, 3], "B": ["a", "b", "c"]},
),
)
def test_to_marimo_json(df: IntoDataFrame):
Expand All @@ -46,7 +46,7 @@ def test_to_marimo_json(df: IntoDataFrame):
@pytest.mark.parametrize(
"df",
create_dataframes(
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["duckdb"]
{"A": [1, 2, 3], "B": ["a", "b", "c"]},
),
)
def test_to_marimo_csv(df: IntoDataFrame):
Expand All @@ -62,7 +62,7 @@ def test_to_marimo_csv(df: IntoDataFrame):
@pytest.mark.parametrize(
"df",
create_dataframes(
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["duckdb"]
{"A": [1, 2, 3], "B": ["a", "b", "c"]},
),
)
def test_to_marimo_inline_csv(df: IntoDataFrame):
Expand All @@ -79,7 +79,7 @@ def test_to_marimo_inline_csv(df: IntoDataFrame):
@pytest.mark.parametrize(
"df",
create_dataframes(
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["duckdb"]
{"A": [1, 2, 3], "B": ["a", "b", "c"]},
),
)
def test_data_to_json_string(df: IntoDataFrame):
Expand All @@ -95,9 +95,7 @@ def test_data_to_json_string(df: IntoDataFrame):
@pytest.mark.parametrize(
"df",
# We skip pyarrow because its CSV is formatted differently
create_dataframes(
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["pyarrow", "duckdb"]
),
create_dataframes({"A": [1, 2, 3], "B": ["a", "b", "c"]}),
)
def test_data_to_csv_string(df: IntoDataFrame):
result = _data_to_csv_string(df)
Expand Down Expand Up @@ -146,7 +144,6 @@ def test_to_marimo_csv_with_missing_values(df: IntoDataFrame):
"df",
create_dataframes(
{"A": range(10000), "B": [f"value_{i}" for i in range(10000)]},
exclude=["pyarrow", "duckdb"],
),
)
def test_to_marimo_inline_csv_large_dataset(df: IntoDataFrame):
Expand Down Expand Up @@ -229,7 +226,7 @@ def test_register_transformers(mock_data_transformers: MagicMock):
)


SUPPORTS_ARROW_IPC = ["pandas", "polars", "lazy-polars"]
SUPPORTS_ARROW_IPC: list[DFType] = ["pandas", "polars", "lazy-polars"]


@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed")
Expand Down
Loading
Loading