Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions marimo/_plugins/ui/_impl/dataframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Final,
Expand All @@ -13,6 +14,8 @@
Union,
)

import narwhals.stable.v2 as nw

from marimo._output.rich_help import mddoc
from marimo._plugins.ui._core.ui_element import UIElement
from marimo._plugins.ui._impl.dataframes.transforms.apply import (
Expand Down Expand Up @@ -44,8 +47,12 @@
)
from marimo._runtime.functions import EmptyArgs, Function
from marimo._utils.memoize import memoize_last_value
from marimo._utils.narwhals_utils import is_narwhals_lazyframe
from marimo._utils.parse_dataclass import parse_raw

if TYPE_CHECKING:
from narwhals.typing import IntoLazyFrame


@dataclass
class GetDataFrameResponse:
Expand Down Expand Up @@ -86,7 +93,7 @@ def __init__(self, error: str):
class dataframe(UIElement[dict[str, Any], DataFrameType]):
"""Run transformations on a DataFrame or series.

Currently only Pandas or Polars DataFrames are supported.
Currently supports Pandas, Polars, Ibis, Pyarrow, and DuckDB.

Examples:
```python
Expand Down Expand Up @@ -138,14 +145,17 @@ def __init__(
except Exception:
pass

# Make the dataframe lazy and keep track of whether it was lazy originally
nw_df: nw.LazyFrame[Any] = nw.from_native(df, pass_through=False)
self._was_lazy = is_narwhals_lazyframe(nw_df)
nw_df = nw_df.lazy()

self._limit = limit
self._dataframe_name = dataframe_name
self._data = df
self._handler = handler
self._manager = self._get_cached_table_manager(df, self._limit)
self._transform_container = TransformsContainer[DataFrameType](
df, handler
)
self._transform_container = TransformsContainer(nw_df, handler)
self._error: Optional[str] = None
self._last_transforms = Transformations([])
self._page_size = page_size or 5 # Default to 5 rows (.head())
Expand Down Expand Up @@ -210,12 +220,14 @@ def _get_dataframe(self, _args: EmptyArgs) -> GetDataFrameResponse:
row_headers=manager.get_row_headers(),
field_types=manager.get_field_types(),
python_code=self._handler.as_python_code(
self._transform_container._snapshot_df,
self._dataframe_name,
# manager.get_column_names(),
self._manager.get_column_names(),
self._last_transforms.transforms,
),
sql_code=self._handler.as_sql_code(manager.data),
sql_code=self._handler.as_sql_code(
self._transform_container._snapshot_df
),
)

def _get_column_values(
Expand Down Expand Up @@ -245,19 +257,19 @@ def _get_column_values(
def _convert_value(self, value: dict[str, Any]) -> DataFrameType:
if value is None:
self._error = None
return self._data
return _maybe_collect(self._data, self._was_lazy)

try:
transformations = parse_raw(value, Transformations)
result = self._transform_container.apply(transformations)
self._error = None
self._last_transforms = transformations
return result
return _maybe_collect(result, self._was_lazy)
except Exception as e:
error = f"Error applying dataframe transform: {str(e)}\n\n"
sys.stderr.write(error)
self._error = error
return self._data
return _maybe_collect(self._data, self._was_lazy)

def _search(self, args: SearchTableArgs) -> SearchTableResponse:
offset = args.page_number * args.page_size
Expand Down Expand Up @@ -304,7 +316,7 @@ def _apply_filters_query_sort(
self,
query: Optional[str],
sort: Optional[list[SortArgs]],
) -> TableManager[Any]:
) -> TableManager[DataFrameType]:
result = self._get_cached_table_manager(self._value, self._limit)

if query:
Expand All @@ -320,9 +332,17 @@ def _apply_filters_query_sort(

@memoize_last_value
def _get_cached_table_manager(
self, value: Any, limit: Optional[int]
) -> TableManager[Any]:
self, value: DataFrameType, limit: Optional[int]
) -> TableManager[DataFrameType]:
tm = get_table_manager(value)
if limit is not None:
tm = tm.take(limit, 0)
return tm


def _maybe_collect(
df: nw.LazyFrame[IntoLazyFrame], was_lazy: bool
) -> DataFrameType:
if was_lazy:
return df.collect().to_native()
return df.to_native()
59 changes: 23 additions & 36 deletions marimo/_plugins/ui/_impl/dataframes/transforms/apply.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from typing import Any, Generic, TypeVar
from typing import TYPE_CHECKING, TypeVar

from narwhals.dependencies import is_narwhals_dataframe

from marimo._dependencies.dependencies import DependencyManager
from marimo._plugins.ui._impl.dataframes.transforms.handlers import (
IbisTransformHandler,
PandasTransformHandler,
PolarsTransformHandler,
NarwhalsTransformHandler,
)
from marimo._plugins.ui._impl.dataframes.transforms.types import (
DataFrameType,
Transform,
Transformations,
TransformHandler,
TransformType,
)
from marimo._utils.assert_never import assert_never
from marimo._utils.narwhals_utils import can_narwhalify

T = TypeVar("T")


if TYPE_CHECKING:
import narwhals.stable.v2 as nw
from narwhals.typing import IntoLazyFrame


def _handle(df: T, handler: TransformHandler[T], transform: Transform) -> T:
if transform.type is TransformType.COLUMN_CONVERSION:
return handler.handle_column_conversion(df, transform)
Expand Down Expand Up @@ -61,54 +63,39 @@ def _apply_transforms(


def get_handler_for_dataframe(
df: Any,
) -> TransformHandler[Any]:
df: DataFrameType,
) -> NarwhalsTransformHandler:
"""
Gets the handler for the given dataframe.

raises ValueError if the dataframe type is not supported.
"""
if DependencyManager.pandas.imported():
import pandas as pd

if isinstance(df, pd.DataFrame):
return PandasTransformHandler()
if DependencyManager.polars.imported():
import polars as pl

if isinstance(df, pl.DataFrame):
return PolarsTransformHandler()

if DependencyManager.ibis.imported():
import ibis # type: ignore

if isinstance(df, ibis.Table):
return IbisTransformHandler()

if DependencyManager.narwhals.imported():
if is_narwhals_dataframe(df):
return get_handler_for_dataframe(df.to_native())
if not can_narwhalify(df):
raise ValueError(
f"Unsupported dataframe type. Must be Pandas, Polars, Ibis, Pyarrow, or DuckDB. Got: {type(df)}"
)

raise ValueError(
"Unsupported dataframe type. Must be Pandas or Polars."
f" Got: {type(df)}"
)
return NarwhalsTransformHandler()


class TransformsContainer(Generic[T]):
class TransformsContainer:
"""
Keeps internal state of the last transformation applied to the dataframe.
So that we can incrementally apply transformations.
"""

def __init__(self, df: T, handler: TransformHandler[T]) -> None:
def __init__(
self,
df: nw.LazyFrame[IntoLazyFrame],
handler: NarwhalsTransformHandler,
) -> None:
self._original_df = df
# The dataframe for the given transform.
self._snapshot_df = df
self._handler = handler
self._transforms: list[Transform] = []

def apply(self, transform: Transformations) -> T:
def apply(self, transform: Transformations) -> nw.LazyFrame[IntoLazyFrame]:
"""
Applies the given transformations to the dataframe.
"""
Expand Down
Loading
Loading