Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ interface Data {
label?: string | null;
columns: ColumnDataTypes;
pageSize: number;
showDownload: boolean;
}

// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
Expand Down Expand Up @@ -67,6 +68,7 @@ type PluginFunctions = {
data: TableData<T>;
total_rows: number;
}>;
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
};

// Value is selection, but it is not currently exposed to the user
Expand All @@ -77,13 +79,14 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe")
z.object({
label: z.string().nullish(),
pageSize: z.number().default(5),
showDownload: z.boolean().default(true),
columns: z
.array(z.tuple([z.string().or(z.number()), z.string(), z.string()]))
.transform((value) => {
const map = new Map<ColumnId, string>();
value.forEach(([key, dataType]) =>
map.set(key as ColumnId, dataType as DataType),
);
value.forEach(([key, dataType]) => {
map.set(key as ColumnId, dataType as DataType);
});
return map;
}),
}),
Expand Down Expand Up @@ -124,6 +127,13 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe")
total_rows: z.number(),
}),
),
download_as: rpc
.input(
z.object({
format: z.enum(["csv", "json", "parquet"]),
}),
)
.output(z.string()),
})
.renderer((props) => (
<TableProviders>
Expand All @@ -141,6 +151,8 @@ interface DataTableProps extends Data, PluginFunctions {
value: S;
setValue: (value: S) => void;
host: HTMLElement;
showDownload: boolean;
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
}

const EMPTY: Transformations = {
Expand All @@ -151,11 +163,13 @@ export const DataFrameComponent = memo(
({
columns,
pageSize,
showDownload,
value,
setValue,
get_dataframe,
get_column_values,
search,
download_as,
host,
}: DataTableProps): JSX.Element => {
const { data, error, isPending } = useAsyncData(
Expand Down Expand Up @@ -270,8 +284,8 @@ export const DataFrameComponent = memo(
pagination={true}
fieldTypes={field_types}
rowHeaders={row_headers || Arrays.EMPTY}
showDownload={false}
download_as={Functions.THROW}
showDownload={showDownload}
download_as={download_as}
enableSearch={false}
showFilters={false}
search={search}
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/stories/dataframe.stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ export const DataFrame: StoryObj = {
get_dataframe={() => Promise.reject(new Error("not implemented"))}
search={Functions.THROW}
host={document.body}
showDownload={false}
download_as={async () => ""}
/>
);
},
Expand Down
45 changes: 45 additions & 0 deletions marimo/_plugins/ui/_impl/dataframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Union,
)

import marimo._output.data.data as mo_data
from marimo._output.rich_help import mddoc
from marimo._plugins.ui._core.ui_element import UIElement
from marimo._plugins.ui._impl.dataframes.transforms.apply import (
Expand All @@ -24,6 +25,7 @@
Transformations,
)
from marimo._plugins.ui._impl.table import (
DownloadAsArgs,
SearchTableArgs,
SearchTableResponse,
SortArgs,
Expand Down Expand Up @@ -101,6 +103,8 @@ class dataframe(UIElement[dict[str, Any], DataFrameType]):
limit (Optional[int], optional): The number of items to load into memory, in case
the data is remote and lazily fetched. This is likely true for SQL-backed
dataframes via Ibis.
show_download (bool, optional): Whether to show the download button.
Defaults to True.
on_change (Optional[Callable[[DataFrameType], None]], optional): Optional callback
to run when this element's value changes.
"""
Expand All @@ -113,6 +117,7 @@ def __init__(
on_change: Optional[Callable[[DataFrameType], None]] = None,
page_size: Optional[int] = 5,
limit: Optional[int] = None,
show_download: bool = True,
) -> None:
validate_no_integer_columns(df)
# This will raise an error if the dataframe type is not supported.
Expand Down Expand Up @@ -144,6 +149,7 @@ def __init__(
self._error: Optional[str] = None
self._last_transforms = Transformations([])
self._page_size = page_size or 5 # Default to 5 rows (.head())
self._show_download = show_download
validate_page_size(self._page_size)

super().__init__(
Expand All @@ -158,6 +164,7 @@ def __init__(
"dataframe-name": dataframe_name,
"total": self._manager.get_num_rows(force=False),
"page-size": page_size,
"show-download": show_download,
},
functions=(
Function(
Expand All @@ -175,6 +182,11 @@ def __init__(
arg_cls=SearchTableArgs,
function=self._search,
),
Function(
name="download_as",
arg_cls=DownloadAsArgs,
function=self._download_as,
),
),
)

Expand Down Expand Up @@ -266,6 +278,39 @@ def _search(self, args: SearchTableArgs) -> SearchTableResponse:
total_rows=result.get_num_rows(force=True) or 0,
)

def _download_as(self, args: DownloadAsArgs) -> str:
    """Serialize the current (transformed) dataframe for download.

    The element's value — i.e. the dataframe with every transformation
    the user has applied — is serialized in the requested format and
    exposed as a virtual-file URL.

    Args:
        args (DownloadAsArgs): Holds the target format; one of
            'csv', 'json', or 'parquet'.

    Returns:
        str: URL pointing at the serialized data.

    Raises:
        ValueError: If the requested format is not supported.
    """
    # The transformed dataframe is the element's current value.
    transformed = self._value
    # Reuse the cached manager so repeated downloads avoid re-wrapping.
    manager = self._get_cached_table_manager(transformed, self._limit)

    fmt = args.format
    if fmt == "csv":
        return mo_data.csv(manager.to_csv()).url
    if fmt == "json":
        return mo_data.json(manager.to_json()).url
    if fmt == "parquet":
        return mo_data.parquet(manager.to_parquet()).url
    raise ValueError(
        "format must be one of 'csv', 'json', or 'parquet'."
    )

def _apply_filters_query_sort(
self,
query: Optional[str],
Expand Down
131 changes: 130 additions & 1 deletion tests/_plugins/ui/_impl/dataframes/test_dataframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from typing import Any
from unittest.mock import Mock, patch

Expand All @@ -12,8 +13,13 @@
GetColumnValuesArgs,
GetColumnValuesResponse,
)
from marimo._plugins.ui._impl.table import SearchTableArgs, TableSearchError
from marimo._plugins.ui._impl.table import (
DownloadAsArgs,
SearchTableArgs,
TableSearchError,
)
from marimo._runtime.functions import EmptyArgs
from marimo._utils.data_uri import from_data_uri
from marimo._utils.platform import is_windows
from tests._data.mocks import create_dataframes

Expand Down Expand Up @@ -265,6 +271,129 @@ def test_dataframe_with_limit(df: Any) -> None:
)
assert search_result.total_rows == 100

@staticmethod
@pytest.mark.skipif(
    not HAS_DEPS, reason="optional dependencies not installed"
)
def test_dataframe_show_download() -> None:
    """The `show-download` component arg mirrors the constructor flag."""
    frame = pd.DataFrame({"A": [1, 2, 3], "B": ["a", "b", "c"]})

    # Omitted flag: download button shown by default.
    assert ui.dataframe(frame)._component_args["show-download"] is True
    # Explicit True is passed through unchanged.
    assert (
        ui.dataframe(frame, show_download=True)._component_args[
            "show-download"
        ]
        is True
    )
    # Explicit False disables the button.
    assert (
        ui.dataframe(frame, show_download=False)._component_args[
            "show-download"
        ]
        is False
    )

@staticmethod
@pytest.mark.skipif(
    not HAS_DEPS, reason="optional dependencies not installed"
)
@pytest.mark.parametrize("format_type", ["csv", "json", "parquet"])
def test_dataframe_download_formats(format_type) -> None:
    """Every supported format yields a non-empty data-URI payload."""
    frame = pd.DataFrame(
        {
            "cities": ["Newark", "New York", "Los Angeles"],
            "population": [311549, 8336817, 3898747],
        }
    )
    element = ui.dataframe(frame)

    # No transformations applied: download should still work.
    url = element._download_as(DownloadAsArgs(format=format_type))
    assert url.startswith("data:")

    # Decoded payload must be non-empty.
    payload = from_data_uri(url)[1]
    assert len(payload) > 0

@staticmethod
@pytest.mark.skipif(
    not HAS_DEPS, reason="optional dependencies not installed"
)
def test_dataframe_download_with_transformations() -> None:
    """Downloads reflect the element's current (transformed) value."""
    frame = pd.DataFrame(
        {
            "name": ["Alice", "Bob", "Charlie"],
            "age": [25, 30, 35],
            "city": ["New York", "Newark", "Los Angeles"],
        }
    )
    element = ui.dataframe(frame)

    # Simulate a filter transformation that the UI would have applied.
    element._value = frame[frame["age"] > 27]

    # Download with the transformation in effect.
    url = element._download_as(DownloadAsArgs(format="json"))
    payload = from_data_uri(url)[1]
    rows = json.loads(payload.decode("utf-8"))

    # Only Bob and Charlie survive the age > 27 filter.
    assert len(rows) == 2
    names = [row["name"] for row in rows]
    assert "Bob" in names
    assert "Charlie" in names
    assert "Alice" not in names

@staticmethod
@pytest.mark.skipif(
    not HAS_DEPS, reason="optional dependencies not installed"
)
def test_dataframe_download_empty() -> None:
    """Downloading an empty dataframe still produces a valid CSV."""
    element = ui.dataframe(pd.DataFrame({"A": [], "B": []}))

    url = element._download_as(DownloadAsArgs(format="csv"))
    text = from_data_uri(url)[1].decode("utf-8")

    # At minimum the header (or first column name) must be present.
    assert "A,B" in text or "A" in text

@staticmethod
@pytest.mark.skipif(
    not HAS_DEPS, reason="optional dependencies not installed"
)
def test_dataframe_download_unsupported_format() -> None:
    """An unknown format is rejected with a descriptive ValueError."""
    element = ui.dataframe(pd.DataFrame({"A": [1, 2, 3]}))

    # 'xml' is not a supported download format.
    with pytest.raises(ValueError) as exc_info:
        element._download_as(DownloadAsArgs(format="xml"))

    message = str(exc_info.value)
    assert "format must be one of 'csv', 'json', or 'parquet'" in message

@staticmethod
@pytest.mark.skipif(
    not HAS_DEPS, reason="optional dependencies not installed"
)
@pytest.mark.parametrize(
    "df",
    create_dataframes(
        {"A": [1, 2, 3], "B": ["x", "y", "z"]},
        exclude=["pyarrow", "duckdb", "lazy-polars"],
    ),
)
def test_dataframe_download_different_backends(df) -> None:
    """Download works for each format across dataframe backends.

    Backends that cannot serialize a given format are skipped rather
    than failed, but a *successful* download that returns a malformed
    URL must still fail the test — so the assertion lives outside the
    try/except. (Previously the assert was inside the broad except,
    which silently skipped genuine regressions.)
    """
    subject = ui.dataframe(df)

    for format_type in ["csv", "json", "parquet"]:
        try:
            download_url = subject._download_as(
                DownloadAsArgs(format=format_type)
            )
        except Exception as e:
            # Some backends might not support all formats.
            pytest.skip(f"Backend doesn't support {format_type}: {e}")
        # A completed download must always be a data URI.
        assert download_url.startswith("data:")

@staticmethod
@pytest.mark.parametrize(
"df",
Expand Down
Loading