Skip to content

Commit 091c8f9

Browse files
authored
improvement: use lazy-frames for multi-step processing, include more tests (#6608)
After moving to narwhals v2, we can support more dataframe types in certain codepaths and tests. This also makes multi-step operations lazy, avoiding the materialization of intermediate dataframe states.
1 parent 68c75ff commit 091c8f9

File tree

11 files changed

+138
-85
lines changed

11 files changed

+138
-85
lines changed

CONTRIBUTING.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,25 @@ make py-test
157157
Run a specific test
158158

159159
```bash
160-
hatch run test:test tests/_ast/
160+
hatch run +py=3.13 test:test tests/_ast/
161+
```
162+
163+
Run all changed tests
164+
165+
```bash
166+
hatch run +py=3.13 test:test --picked
161167
```
162168

163169
Run tests with optional dependencies
164170

165171
```bash
166-
hatch run test-optional:test tests/_ast/
172+
hatch run +py=3.13 test-optional:test tests/_ast/
167173
```
168174

169-
Run tests with a specific Python version
175+
Run tests across all Python versions (omit `+py`)
170176

171177
```bash
172-
hatch run +py=3.10 test:test tests/_ast/
178+
hatch run test:test tests/_ast/
173179
```
174180

175181
Run all tests across all Python versions

marimo/_data/preview_column.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def _get_altair_chart(
276276
try:
277277
chart_spec = _get_chart_spec(
278278
# Downgrade to v1 since altair doesn't support v2 yet
279-
# This is valiadted with our tests, so if the tests pass with this
279+
# This is validated with our tests, so if the tests pass with this
280280
# removed, we can remove the downgrade.
281281
column_data=downgrade_narwhals_df_to_v1(column_data),
282282
column_type=column_type,
@@ -369,49 +369,49 @@ def _sanitize_data(
369369
total_seconds = diff.total_seconds()
370370
if total_seconds >= 604800:
371371
# Use weeks if range is at least a week
372-
column_data = column_data.with_columns(
372+
column_data = frame.with_columns(
373373
(col.dt.total_seconds() / 604800).alias(
374374
column_name
375375
)
376376
)
377377
elif total_seconds >= 86400:
378378
# Use days if range is at least a day
379-
column_data = column_data.with_columns(
379+
column_data = frame.with_columns(
380380
(col.dt.total_seconds() / 86400).alias(column_name)
381381
)
382382
elif total_seconds >= 3600:
383383
# Use hours if range is at least an hour
384-
column_data = column_data.with_columns(
384+
column_data = frame.with_columns(
385385
(col.dt.total_seconds() / 3600).alias(column_name)
386386
)
387387
elif total_seconds >= 60:
388388
# Use minutes if range is at least a minute
389-
column_data = column_data.with_columns(
389+
column_data = frame.with_columns(
390390
col.dt.total_minutes().alias(column_name)
391391
)
392392
elif total_seconds >= 1:
393393
# Use seconds if range is at least a second
394-
column_data = column_data.with_columns(
394+
column_data = frame.with_columns(
395395
col.dt.total_seconds().alias(column_name)
396396
)
397397
elif total_seconds >= 0.001:
398398
# Use milliseconds if range is at least a millisecond
399-
column_data = column_data.with_columns(
399+
column_data = frame.with_columns(
400400
col.dt.total_milliseconds().alias(column_name)
401401
)
402402
elif total_seconds >= 0.000001:
403403
# Use microseconds if range is at least a microsecond
404-
column_data = column_data.with_columns(
404+
column_data = frame.with_columns(
405405
col.dt.total_microseconds().alias(column_name)
406406
)
407407
elif total_seconds >= 0.000000001:
408408
# Use nanoseconds if range is at least a nanosecond
409-
column_data = column_data.with_columns(
409+
column_data = frame.with_columns(
410410
col.dt.total_nanoseconds().alias(column_name)
411411
)
412412
except Exception as e:
413413
LOGGER.warning("Failed to infer duration precision: %s", e)
414-
column_data = column_data.with_columns(
414+
column_data = frame.with_columns(
415415
col.dt.total_seconds().alias(column_name)
416416
)
417417
except Exception as e:

marimo/_plugins/ui/_impl/altair_chart.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright 2024 Marimo. All rights reserved.
2+
23
from __future__ import annotations
34

45
import datetime
@@ -29,6 +30,7 @@
2930
assert_can_narwhalify,
3031
can_narwhalify,
3132
empty_df,
33+
is_narwhals_lazyframe,
3234
)
3335

3436
if sys.version_info < (3, 10):
@@ -100,7 +102,11 @@ def _using_vegafusion() -> bool:
100102
def _filter_dataframe(
101103
native_df: Union[IntoDataFrame, IntoLazyFrame], selection: ChartSelection
102104
) -> Union[IntoDataFrame, IntoLazyFrame]:
103-
df = nw.from_native(native_df)
105+
# Use lazy evaluation for efficient chained filtering
106+
base = nw.from_native(native_df)
107+
is_lazy = is_narwhals_lazyframe(base)
108+
df = base.lazy()
109+
104110
if not isinstance(selection, dict):
105111
raise TypeError("Input 'selection' must be a dictionary")
106112

@@ -121,10 +127,12 @@ def _filter_dataframe(
121127
vgsid = fields.get("_vgsid_", [])
122128
try:
123129
indexes = [int(i) - 1 for i in vgsid]
124-
df = cast(nw.DataFrame[Any], df)[indexes]
130+
# Need to collect for index-based selection
131+
non_lazy = df.collect()[indexes]
132+
df = non_lazy.lazy()
125133
except IndexError:
126134
# Out of bounds index, return empty dataframe if it's the
127-
df = cast(nw.DataFrame[Any], df)[[]]
135+
df = df.head(0)
128136
LOGGER.error(f"Invalid index in selection: {vgsid}")
129137
except ValueError:
130138
LOGGER.error(f"Invalid index in selection: {vgsid}")
@@ -140,10 +148,12 @@ def _filter_dataframe(
140148
if field in ("vlPoint", "_vgsid_"):
141149
continue
142150

143-
if field not in df.columns:
151+
# Need to collect schema to check columns and dtypes
152+
schema = df.collect_schema()
153+
if field not in schema.names():
144154
raise ValueError(f"Field '{field}' not found in DataFrame")
145155

146-
dtype = df[field].dtype
156+
dtype = schema[field]
147157
resolved_values = _resolve_values(values, dtype)
148158
if is_point_selection:
149159
df = df.filter(nw.col(field).is_in(resolved_values))
@@ -168,7 +178,11 @@ def _filter_dataframe(
168178
f"Invalid selection: {field}={resolved_values}"
169179
)
170180

171-
return nw.to_native(df)
181+
if not is_lazy and is_narwhals_lazyframe(df):
182+
# Undo the lazy
183+
return df.collect().to_native() # type: ignore[no-any-return]
184+
185+
return df.to_native()
172186

173187

174188
def _resolve_values(values: Any, dtype: Any) -> list[Any]:

marimo/_plugins/ui/_impl/charts/altair_transformer.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
get_table_manager_or_none,
1616
)
1717
from marimo._utils.data_uri import build_data_url
18-
from marimo._utils.narwhals_utils import can_narwhalify
18+
from marimo._utils.narwhals_utils import can_narwhalify, is_narwhals_lazyframe
1919

2020
LOGGER = _loggers.marimo_logger()
2121

@@ -154,7 +154,16 @@ def sanitize_nan_infs(data: Any) -> Any:
154154
"""Sanitize NaN and Inf values in Dataframes for JSON serialization."""
155155
if can_narwhalify(data):
156156
narwhals_data = nw.from_native(data)
157-
for col, dtype in narwhals_data.schema.items():
157+
is_prev_lazy = isinstance(narwhals_data, nw.LazyFrame)
158+
159+
# Convert to lazy for optimization if not already lazy
160+
if not is_prev_lazy:
161+
narwhals_data = narwhals_data.lazy()
162+
163+
# Get schema without collecting
164+
schema = narwhals_data.collect_schema()
165+
166+
for col, dtype in schema.items():
158167
# Only numeric columns can have NaN or Inf values
159168
if dtype.is_numeric():
160169
narwhals_data = narwhals_data.with_columns(
@@ -163,6 +172,11 @@ def sanitize_nan_infs(data: Any) -> Any:
163172
.otherwise(nw.col(col))
164173
.name.keep()
165174
)
175+
176+
# Collect if input was eager
177+
if not is_prev_lazy and is_narwhals_lazyframe(narwhals_data):
178+
narwhals_data = narwhals_data.collect()
179+
166180
return narwhals_data.to_native()
167181
return data
168182

marimo/_utils/narwhals_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,10 @@ def downgrade_narwhals_df_to_v1(
224224
"""
225225
Downgrade a narwhals dataframe to the latest version.
226226
"""
227-
return nw1.from_native(df.to_native()) # type: ignore[no-any-return]
227+
if is_narwhals_lazyframe(df) or is_narwhals_dataframe(df):
228+
return nw1.from_native(df.to_native()) # type: ignore[no-any-return]
229+
# Pass through
230+
return df
228231

229232

230233
def is_narwhals_lazyframe(df: Any) -> TypeIs[nw.LazyFrame[Any]]:

tests/_data/test_preview_column.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from datetime import datetime, time, timedelta
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Any
55
from unittest.mock import patch
66

77
import pytest
@@ -19,6 +19,7 @@
1919
)
2020
from marimo._runtime.requests import PreviewDatasetColumnRequest
2121
from marimo._utils.platform import is_windows
22+
from tests._data.mocks import create_dataframes
2223
from tests.mocks import snapshotter
2324
from tests.utils import assert_serialize_roundtrip
2425

@@ -60,13 +61,15 @@ def cleanup() -> Generator[None, None, None]:
6061
not HAS_DF_DEPS, reason="optional dependencies not installed"
6162
)
6263
@pytest.mark.skipif(is_windows(), reason="Windows encodes base64 differently")
63-
def test_get_column_preview_for_dataframe() -> None:
64-
import pandas as pd
65-
64+
@pytest.mark.parametrize(
65+
"df",
66+
create_dataframes(
67+
{"A": [1, 2, 3], "B": ["a", "a", "a"]},
68+
),
69+
)
70+
def test_get_column_preview_for_dataframe(df: Any) -> None:
6671
register_transformers()
6772

68-
df = pd.DataFrame({"A": [1, 2, 3], "B": ["a", "a", "a"]})
69-
7073
for column_name in ["A", "B"]:
7174
result = get_column_preview_for_dataframe(
7275
df,
@@ -518,9 +521,9 @@ def test_sanitize_dtypes() -> None:
518521

519522

520523
@pytest.mark.skipif(
521-
not DependencyManager.narwhals.has(), reason="narwhals not installed"
524+
not DependencyManager.narwhals.has() or not DependencyManager.polars.has(),
525+
reason="narwhals and polars not installed",
522526
)
523-
@pytest.mark.xfail(reason="Sanitizing is failing") # TODO: Fix this
524527
def test_sanitize_dtypes_enum() -> None:
525528
import narwhals as nw
526529
import polars as pl
@@ -534,17 +537,16 @@ def test_sanitize_dtypes_enum() -> None:
534537
nw_df = nw.from_native(df)
535538

536539
result = _sanitize_data(nw_df, "enum_col")
537-
assert result.schema["enum_col"] == nw.String
540+
assert result.collect_schema()["enum_col"] == nw.String
538541

542+
lazy_df = nw_df.lazy()
543+
result = _sanitize_data(lazy_df, "enum_col")
544+
assert result.collect_schema()["enum_col"] == nw.String
539545

540-
@pytest.mark.skipif(
541-
not DependencyManager.polars.has(), reason="polars not installed"
542-
)
543-
def test_preview_column_duration_dtype() -> None:
544-
import polars as pl
545546

546-
# Test days conversion
547-
df = pl.DataFrame(
547+
@pytest.mark.parametrize(
548+
"df",
549+
create_dataframes(
548550
{
549551
"duration_weeks": [timedelta(weeks=1), timedelta(weeks=2)],
550552
"duration_days": [timedelta(days=1), timedelta(days=2)],
@@ -559,9 +561,11 @@ def test_preview_column_duration_dtype() -> None:
559561
timedelta(microseconds=1),
560562
timedelta(microseconds=2),
561563
],
562-
}
563-
)
564-
564+
},
565+
exclude=["pyarrow", "duckdb", "ibis"],
566+
),
567+
)
568+
def test_preview_column_duration_dtype(df) -> None:
565569
for column_name in df.columns:
566570
result = get_column_preview_dataset(
567571
table=get_table_manager(df),

tests/_plugins/ui/_impl/charts/test_altair_transformers.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
_to_marimo_json,
1919
register_transformers,
2020
)
21-
from tests._data.mocks import create_dataframes
21+
from tests._data.mocks import DFType, create_dataframes
2222

2323
HAS_DEPS = DependencyManager.pandas.has() and DependencyManager.altair.has()
2424

@@ -30,7 +30,7 @@
3030
@pytest.mark.parametrize(
3131
"df",
3232
create_dataframes(
33-
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["duckdb"]
33+
{"A": [1, 2, 3], "B": ["a", "b", "c"]},
3434
),
3535
)
3636
def test_to_marimo_json(df: IntoDataFrame):
@@ -46,7 +46,7 @@ def test_to_marimo_json(df: IntoDataFrame):
4646
@pytest.mark.parametrize(
4747
"df",
4848
create_dataframes(
49-
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["duckdb"]
49+
{"A": [1, 2, 3], "B": ["a", "b", "c"]},
5050
),
5151
)
5252
def test_to_marimo_csv(df: IntoDataFrame):
@@ -62,7 +62,7 @@ def test_to_marimo_csv(df: IntoDataFrame):
6262
@pytest.mark.parametrize(
6363
"df",
6464
create_dataframes(
65-
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["duckdb"]
65+
{"A": [1, 2, 3], "B": ["a", "b", "c"]},
6666
),
6767
)
6868
def test_to_marimo_inline_csv(df: IntoDataFrame):
@@ -79,7 +79,7 @@ def test_to_marimo_inline_csv(df: IntoDataFrame):
7979
@pytest.mark.parametrize(
8080
"df",
8181
create_dataframes(
82-
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["duckdb"]
82+
{"A": [1, 2, 3], "B": ["a", "b", "c"]},
8383
),
8484
)
8585
def test_data_to_json_string(df: IntoDataFrame):
@@ -95,9 +95,7 @@ def test_data_to_json_string(df: IntoDataFrame):
9595
@pytest.mark.parametrize(
9696
"df",
9797
# We skip pyarrow because it's csv is formatted differently
98-
create_dataframes(
99-
{"A": [1, 2, 3], "B": ["a", "b", "c"]}, exclude=["pyarrow", "duckdb"]
100-
),
98+
create_dataframes({"A": [1, 2, 3], "B": ["a", "b", "c"]}),
10199
)
102100
def test_data_to_csv_string(df: IntoDataFrame):
103101
result = _data_to_csv_string(df)
@@ -146,7 +144,6 @@ def test_to_marimo_csv_with_missing_values(df: IntoDataFrame):
146144
"df",
147145
create_dataframes(
148146
{"A": range(10000), "B": [f"value_{i}" for i in range(10000)]},
149-
exclude=["pyarrow", "duckdb"],
150147
),
151148
)
152149
def test_to_marimo_inline_csv_large_dataset(df: IntoDataFrame):
@@ -229,7 +226,7 @@ def test_register_transformers(mock_data_transformers: MagicMock):
229226
)
230227

231228

232-
SUPPORTS_ARROW_IPC = ["pandas", "polars", "lazy-polars"]
229+
SUPPORTS_ARROW_IPC: list[DFType] = ["pandas", "polars", "lazy-polars"]
233230

234231

235232
@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed")

0 commit comments

Comments
 (0)