Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions marimo/_data/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw

from marimo._data.models import DataType
from marimo._utils import assert_never
Expand Down Expand Up @@ -351,7 +351,9 @@ def _guess_date_format(
if not can_narwhalify(data, eager_only=True):
return self.DEFAULT_DATE_FORMAT, self.DEFAULT_TIME_UNIT

df = nw.from_native(data, eager_only=True)
df: nw.DataFrame[Any] = nw.from_native(
data, pass_through=True, eager_only=True
)

# Get min and max dates using narwhals
min_date = df[column].min()
Expand Down
8 changes: 6 additions & 2 deletions marimo/_data/preview_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Any, Optional

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw

from marimo import _loggers
from marimo._data.charts import get_chart_builder
Expand All @@ -21,6 +21,7 @@
from marimo._plugins.ui._impl.tables.utils import get_table_manager_or_none
from marimo._runtime.requests import PreviewDatasetColumnRequest
from marimo._sql.utils import wrapped_sql
from marimo._utils.narwhals_utils import downgrade_narwhals_df_to_v1

LOGGER = _loggers.marimo_logger()

Expand Down Expand Up @@ -274,7 +275,10 @@ def _get_altair_chart(
# We may not know number of rows, so we can check for max rows error
try:
chart_spec = _get_chart_spec(
column_data=column_data,
# Downgrade to v1 since altair doesn't support v2 yet
# This is valiadted with our tests, so if the tests pass with this
Copy link

Copilot AI Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in comment: 'valiadted' should be 'validated'

Suggested change
# This is valiadted with our tests, so if the tests pass with this
# This is validated with our tests, so if the tests pass with this

Copilot uses AI. Check for mistakes.
# removed, we can remove the downgrade.
column_data=downgrade_narwhals_df_to_v1(column_data),
column_type=column_type,
column_name=column_name,
should_limit_to_10_items=should_limit_to_10_items,
Expand Down
10 changes: 5 additions & 5 deletions marimo/_data/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import Any, cast

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw
from narwhals.typing import IntoSeries

from marimo._utils.narwhals_utils import (
Expand Down Expand Up @@ -54,7 +54,7 @@ def _get_name(series: nw.Series) -> str:
return str(series.name)


@nw.narwhalify(eager_or_interchange_only=True, series_only=True)
@nw.narwhalify(eager_only=True, series_only=True)
def get_number_series_info(series: nw.Series) -> NumberSeriesInfo:
"""
Get the summary of a numeric series.
Expand All @@ -77,7 +77,7 @@ def validate_number(value: Any) -> float:
)


@nw.narwhalify(eager_or_interchange_only=True, series_only=True)
@nw.narwhalify(eager_only=True, series_only=True)
def get_category_series_info(series: nw.Series) -> CategorySeriesInfo:
"""
Get the summary of a categorical series.
Expand All @@ -92,7 +92,7 @@ def get_category_series_info(series: nw.Series) -> CategorySeriesInfo:
)


@nw.narwhalify(eager_or_interchange_only=True, series_only=True)
@nw.narwhalify(eager_only=True, series_only=True)
def get_date_series_info(series: nw.Series) -> DateSeriesInfo:
"""
Get the summary of a date series.
Expand All @@ -116,7 +116,7 @@ def validate_date(value: Any) -> str:
)


@nw.narwhalify(eager_or_interchange_only=True, series_only=True)
@nw.narwhalify(eager_only=True, series_only=True)
def get_datetime_series_info(series: nw.Series) -> DateSeriesInfo:
"""
Get the summary of a datetime series.
Expand Down
2 changes: 1 addition & 1 deletion marimo/_output/formatters/df_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum
from typing import Any

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw

from marimo import _loggers
from marimo._messaging.mimetypes import KnownMimeType
Expand Down
6 changes: 4 additions & 2 deletions marimo/_plugins/core/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from urllib.parse import urlparse

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw

from marimo._dependencies.dependencies import DependencyManager
from marimo._utils.narwhals_utils import can_narwhalify
Expand Down Expand Up @@ -129,7 +129,9 @@ def io_to_data_url(

# Handle Pandas DataFrames (convert to CSV)
if can_narwhalify(src, eager_only=True):
df = nw.from_native(src, pass_through=False, eager_only=True)
df: nw.DataFrame[Any] = nw.from_native(
src, pass_through=True, eager_only=True
)
file = io.BytesIO()
df.write_csv(file)
file.seek(0)
Expand Down
23 changes: 10 additions & 13 deletions marimo/_plugins/ui/_impl/altair_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
cast,
)

import narwhals.stable.v1 as nw
from narwhals.typing import IntoDataFrame
import narwhals.stable.v2 as nw
from narwhals.typing import IntoDataFrame, IntoLazyFrame

from marimo import _loggers
from marimo._dependencies.dependencies import DependencyManager
Expand Down Expand Up @@ -54,7 +54,9 @@
RowOrientedData = list[dict[str, Any]]
ColumnOrientedData = dict[str, list[Any]]

ChartDataType = Union[IntoDataFrame, RowOrientedData, ColumnOrientedData]
ChartDataType = Union[
IntoDataFrame, IntoLazyFrame, RowOrientedData, ColumnOrientedData
]

# Union of all possible chart types
AltairChartType: TypeAlias = "altair.vegalite.v5.api.ChartType"
Expand Down Expand Up @@ -96,8 +98,8 @@ def _using_vegafusion() -> bool:


def _filter_dataframe(
native_df: IntoDataFrame, selection: ChartSelection
) -> IntoDataFrame:
native_df: Union[IntoDataFrame, IntoLazyFrame], selection: ChartSelection
) -> Union[IntoDataFrame, IntoLazyFrame]:
df = nw.from_native(native_df)
if not isinstance(selection, dict):
raise TypeError("Input 'selection' must be a dictionary")
Expand Down Expand Up @@ -182,14 +184,9 @@ def _coerce_value(value: Any, dtype: Any) -> Any:
if nw.Datetime == dtype and isinstance(dtype, nw.Datetime):
if isinstance(value, str):
res = datetime.datetime.fromisoformat(value)
# If dtype has no timezone, shift by local timezone offset
if dtype.time_zone is None:
local_tz = datetime.datetime.now().astimezone().tzinfo
LOGGER.warning(
f"Datetime was given with a timezone when not expected. "
f"Shifting by local timezone offset {local_tz}."
)
return res.astimezone(local_tz).replace(tzinfo=None)
# If dtype has no timezone, but value has timezone, remove timezone without shifting
if dtype.time_zone is None and res.tzinfo is not None:
return res.replace(tzinfo=None)
Comment on lines +187 to +189
Copy link

Copilot AI Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing timezone information without shifting could lead to data inconsistency. The original logic with timezone shifting was more robust. Consider adding a comment explaining why this approach was chosen or validate that this change is intentional.

Suggested change
# If dtype has no timezone, but value has timezone, remove timezone without shifting
if dtype.time_zone is None and res.tzinfo is not None:
return res.replace(tzinfo=None)
# If dtype has no timezone, but value has timezone, shift to UTC before removing timezone info
if dtype.time_zone is None and res.tzinfo is not None:
# Shift to UTC before dropping tzinfo to avoid data inconsistency
return res.astimezone(datetime.timezone.utc).replace(tzinfo=None)

Copilot uses AI. Check for mistakes.
return res

# Value is milliseconds since epoch
Expand Down
6 changes: 4 additions & 2 deletions marimo/_plugins/ui/_impl/charts/altair_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import base64
from typing import Any, Literal, TypedDict, Union

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw
from narwhals.typing import IntoDataFrame

import marimo._output.data.data as mo_data
Expand Down Expand Up @@ -137,7 +137,9 @@ def _maybe_sanitize_dataframe(data: Any) -> Any:
):
narwhals_data = nw.from_native(data)
try:
res: nw.DataFrame[Any] = alt.utils.sanitize_narwhals_dataframe(
import narwhals.stable.v1 as nw1

res: nw1.DataFrame[Any] = alt.utils.sanitize_narwhals_dataframe(
narwhals_data # type: ignore[arg-type]
)
return res.to_native() # type: ignore[return-value]
Expand Down
4 changes: 2 additions & 2 deletions marimo/_plugins/ui/_impl/data_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
cast,
)

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw
from narwhals.typing import IntoDataFrame

import marimo._output.data.data as mo_data
Expand Down Expand Up @@ -303,7 +303,7 @@ def _apply_edits_row_oriented(
def _apply_edits_dataframe(
native_df: IntoDataFrame, edits: DataEdits, schema: Optional[nw.Schema]
) -> IntoDataFrame:
df = nw.from_native(native_df, eager_or_interchange_only=True)
df = nw.from_native(native_df, eager_only=True)
column_oriented = df.to_dict(as_series=False)
schema = schema or cast(nw.Schema, df.schema)

Expand Down
6 changes: 3 additions & 3 deletions marimo/_plugins/ui/_impl/dataframes/transforms/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from typing import Any, Generic, TypeVar

from narwhals.dependencies import is_narwhals_dataframe

from marimo._dependencies.dependencies import DependencyManager
from marimo._plugins.ui._impl.dataframes.transforms.handlers import (
IbisTransformHandler,
Expand Down Expand Up @@ -84,9 +86,7 @@ def get_handler_for_dataframe(
return IbisTransformHandler()

if DependencyManager.narwhals.imported():
import narwhals as nw

if isinstance(df, nw.DataFrame):
if is_narwhals_dataframe(df):
return get_handler_for_dataframe(df.to_native())

raise ValueError(
Expand Down
38 changes: 23 additions & 15 deletions marimo/_plugins/ui/_impl/tables/narwhals_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from typing import Any, Optional, Union, cast

import msgspec
import narwhals.stable.v1 as nw
from narwhals.stable.v1.typing import IntoFrameT
import narwhals.stable.v2 as nw
from narwhals.typing import IntoDataFrameT, IntoLazyFrameT

from marimo import _loggers
from marimo._data.models import BinValue, ColumnStats, ExternalDataType
Expand All @@ -32,36 +32,37 @@
from marimo._utils.narwhals_utils import (
can_narwhalify,
dataframe_to_csv,
downgrade_narwhals_df_to_v1,
is_narwhals_integer_type,
is_narwhals_lazyframe,
is_narwhals_string_type,
is_narwhals_temporal_type,
is_narwhals_time_type,
unwrap_py_scalar,
upgrade_narwhals_df,
)

LOGGER = _loggers.marimo_logger()
UNSTABLE_API_WARNING = "`Series.hist` is being called from the stable API although considered an unstable feature."


class NarwhalsTableManager(
TableManager[Union[nw.DataFrame[IntoFrameT], nw.LazyFrame[IntoFrameT]]]
TableManager[
Union[nw.DataFrame[IntoDataFrameT], nw.LazyFrame[IntoLazyFrameT]]
]
):
type = "narwhals"

@staticmethod
def from_dataframe(data: IntoFrameT) -> NarwhalsTableManager[IntoFrameT]:
def from_dataframe(
data: Union[IntoDataFrameT, IntoLazyFrameT],
) -> NarwhalsTableManager[IntoDataFrameT, IntoLazyFrameT]:
return NarwhalsTableManager(nw.from_native(data, pass_through=False))

def as_frame(self) -> nw.DataFrame[Any]:
if is_narwhals_lazyframe(self.data):
return self.data.collect()
return self.data

def upgrade(self) -> NarwhalsTableManager[Any]:
return NarwhalsTableManager(upgrade_narwhals_df(self.data))

def as_lazy_frame(self) -> nw.LazyFrame[Any]:
if is_narwhals_lazyframe(self.data):
return self.data
Expand All @@ -86,7 +87,7 @@ def to_csv_str(
def to_json_str(
self, format_mapping: Optional[FormatMapping] = None
) -> str:
frame = self.upgrade().apply_formatting(format_mapping).as_frame()
frame = self.apply_formatting(format_mapping).as_frame()
return sanitize_json_bigint(frame.rows(named=True))

def to_parquet(self) -> bytes:
Expand All @@ -96,11 +97,11 @@ def to_parquet(self) -> bytes:

def apply_formatting(
self, format_mapping: Optional[FormatMapping]
) -> NarwhalsTableManager[Any]:
) -> NarwhalsTableManager[IntoDataFrameT, IntoLazyFrameT]:
if not format_mapping:
return self

frame = self.upgrade().as_frame()
frame = self.as_frame()
_data = frame.to_dict(as_series=False).copy()
for col in _data.keys():
if col in format_mapping:
Expand All @@ -114,7 +115,9 @@ def apply_formatting(
def supports_filters(self) -> bool:
return True

def select_rows(self, indices: list[int]) -> TableManager[Any]:
def select_rows(
self, indices: list[int]
) -> TableManager[Union[IntoDataFrameT, IntoLazyFrameT]]:
if not indices:
return self.with_new_data(self.data.head(0))

Expand Down Expand Up @@ -456,7 +459,10 @@ def get_bin_values(self, column: str, num_bins: int) -> list[BinValue]:
if not dtype.is_numeric():
return []

col = self.as_frame().get_column(column)
# Downgrade to v1 since v2 does not support the hist() method yet
downgraded_df = downgrade_narwhals_df_to_v1(self.as_frame())
col = downgraded_df.get_column(column)

bin_start = col.min()
bin_values: list[BinValue] = []

Expand Down Expand Up @@ -484,10 +490,12 @@ def _get_bin_values_temporal(
nw.hist does not support temporal columns, so we convert to numeric
and then convert back to temporal values.
"""
# Convert to timestamp in ms
col = self.as_frame().get_column(column)
# Downgrade to v1 since v2 does not support the hist() method yet
downgraded_df = downgrade_narwhals_df_to_v1(self.as_frame())
col = downgraded_df.get_column(column)

if dtype == nw.Time:
# Convert to timestamp in ms
col_in_ms = (
col.dt.hour().cast(nw.Int64) * 3600000
+ col.dt.minute().cast(nw.Int64) * 60000
Expand Down
4 changes: 2 additions & 2 deletions marimo/_plugins/ui/_impl/tables/pandas_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Optional

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw

from marimo import _loggers
from marimo._data.models import ExternalDataType
Expand Down Expand Up @@ -57,7 +57,7 @@ def package_name() -> str:
def create() -> type[TableManager[Any]]:
import pandas as pd

class PandasTableManager(NarwhalsTableManager[pd.DataFrame]):
class PandasTableManager(NarwhalsTableManager[pd.DataFrame, Any]):
type = "pandas"

def __init__(self, data: pd.DataFrame) -> None:
Expand Down
4 changes: 2 additions & 2 deletions marimo/_plugins/ui/_impl/tables/polars_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import cached_property
from typing import Any, Optional, Union

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw

from marimo import _loggers
from marimo._data.models import (
Expand Down Expand Up @@ -38,7 +38,7 @@ def create() -> type[TableManager[Any]]:
import polars as pl

class PolarsTableManager(
NarwhalsTableManager[Union[pl.DataFrame, pl.LazyFrame]]
NarwhalsTableManager[pl.DataFrame, pl.LazyFrame]
):
type = "polars"

Expand Down
2 changes: 1 addition & 1 deletion marimo/_plugins/ui/_impl/tables/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import TypeVar, cast

import narwhals.stable.v1 as nw
import narwhals.stable.v2 as nw
from narwhals.typing import IntoDataFrame

INDEX_COLUMN_NAME = "_marimo_row_id"
Expand Down
Loading
Loading