diff --git a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx index 3e08d2f6919..e89cee1795f 100644 --- a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx +++ b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx @@ -38,6 +38,7 @@ interface Data { label?: string | null; columns: ColumnDataTypes; pageSize: number; + showDownload: boolean; } // eslint-disable-next-line @typescript-eslint/consistent-type-definitions @@ -67,6 +68,7 @@ type PluginFunctions = { data: TableData; total_rows: number; }>; + download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>; }; // Value is selection, but it is not currently exposed to the user @@ -77,13 +79,14 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe") z.object({ label: z.string().nullish(), pageSize: z.number().default(5), + showDownload: z.boolean().default(true), columns: z .array(z.tuple([z.string().or(z.number()), z.string(), z.string()])) .transform((value) => { const map = new Map<ColumnId, DataType>(); - value.forEach(([key, dataType]) => - map.set(key as ColumnId, dataType as DataType), - ); + value.forEach(([key, dataType]) => { + map.set(key as ColumnId, dataType as DataType); + }); return map; }), }), @@ -124,6 +127,13 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe") total_rows: z.number(), }), ), + download_as: rpc + .input( + z.object({ + format: z.enum(["csv", "json", "parquet"]), + }), + ) + .output(z.string()), }) .renderer((props) => ( @@ -141,6 +151,8 @@ interface DataTableProps extends Data, PluginFunctions { value: S; setValue: (value: S) => void; host: HTMLElement; + showDownload: boolean; + download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>; } const EMPTY: Transformations = { @@ -151,11 +163,13 @@ export const DataFrameComponent = memo( ({ columns, pageSize, + showDownload, value, setValue, get_dataframe, get_column_values, search, + download_as, host, }: DataTableProps): 
JSX.Element => { const { data, error, isPending } = useAsyncData( @@ -270,8 +284,8 @@ export const DataFrameComponent = memo( pagination={true} fieldTypes={field_types} rowHeaders={row_headers || Arrays.EMPTY} - showDownload={false} - download_as={Functions.THROW} + showDownload={showDownload} + download_as={download_as} enableSearch={false} showFilters={false} search={search} diff --git a/frontend/src/stories/dataframe.stories.tsx b/frontend/src/stories/dataframe.stories.tsx index 3d104d38df7..436eba27966 100644 --- a/frontend/src/stories/dataframe.stories.tsx +++ b/frontend/src/stories/dataframe.stories.tsx @@ -40,6 +40,8 @@ export const DataFrame: StoryObj = { get_dataframe={() => Promise.reject(new Error("not implemented"))} search={Functions.THROW} host={document.body} + showDownload={false} + download_as={async () => ""} /> ); }, diff --git a/marimo/_plugins/ui/_impl/dataframes/dataframe.py b/marimo/_plugins/ui/_impl/dataframes/dataframe.py index b49080f12d1..f5646227f74 100644 --- a/marimo/_plugins/ui/_impl/dataframes/dataframe.py +++ b/marimo/_plugins/ui/_impl/dataframes/dataframe.py @@ -13,6 +13,7 @@ Union, ) +import marimo._output.data.data as mo_data from marimo._output.rich_help import mddoc from marimo._plugins.ui._core.ui_element import UIElement from marimo._plugins.ui._impl.dataframes.transforms.apply import ( @@ -24,6 +25,7 @@ Transformations, ) from marimo._plugins.ui._impl.table import ( + DownloadAsArgs, SearchTableArgs, SearchTableResponse, SortArgs, @@ -101,6 +103,8 @@ class dataframe(UIElement[dict[str, Any], DataFrameType]): limit (Optional[int], optional): The number of items to load into memory, in case the data is remote and lazily fetched. This is likely true for SQL-backed dataframes via Ibis. + show_download (bool, optional): Whether to show the download button. + Defaults to True. on_change (Optional[Callable[[DataFrameType], None]], optional): Optional callback to run when this element's value changes. 
""" @@ -113,6 +117,7 @@ def __init__( on_change: Optional[Callable[[DataFrameType], None]] = None, page_size: Optional[int] = 5, limit: Optional[int] = None, + show_download: bool = True, ) -> None: validate_no_integer_columns(df) # This will raise an error if the dataframe type is not supported. @@ -144,6 +149,7 @@ def __init__( self._error: Optional[str] = None self._last_transforms = Transformations([]) self._page_size = page_size or 5 # Default to 5 rows (.head()) + self._show_download = show_download validate_page_size(self._page_size) super().__init__( @@ -158,6 +164,7 @@ def __init__( "dataframe-name": dataframe_name, "total": self._manager.get_num_rows(force=False), "page-size": page_size, + "show-download": show_download, }, functions=( Function( @@ -175,6 +182,11 @@ def __init__( arg_cls=SearchTableArgs, function=self._search, ), + Function( + name="download_as", + arg_cls=DownloadAsArgs, + function=self._download_as, + ), ), ) @@ -266,6 +278,39 @@ def _search(self, args: SearchTableArgs) -> SearchTableResponse: total_rows=result.get_num_rows(force=True) or 0, ) + def _download_as(self, args: DownloadAsArgs) -> str: + """Download the transformed dataframe in the specified format. + + Downloads the dataframe with all current transformations applied. + + Args: + args (DownloadAsArgs): Arguments specifying the download format. + format must be one of 'csv', 'json', or 'parquet'. + + Returns: + str: URL to download the data file. + + Raises: + ValueError: If format is not supported. + """ + # Get transformed dataframe + df = self._value + + # Get the table manager for the transformed data + manager = self._get_cached_table_manager(df, self._limit) + + ext = args.format + if ext == "csv": + return mo_data.csv(manager.to_csv()).url + elif ext == "json": + return mo_data.json(manager.to_json()).url + elif ext == "parquet": + return mo_data.parquet(manager.to_parquet()).url + else: + raise ValueError( + "format must be one of 'csv', 'json', or 'parquet'." 
+ ) + def _apply_filters_query_sort( self, query: Optional[str], diff --git a/tests/_plugins/ui/_impl/dataframes/test_dataframe.py b/tests/_plugins/ui/_impl/dataframes/test_dataframe.py index 2c92a45a4ae..2749cb45e8e 100644 --- a/tests/_plugins/ui/_impl/dataframes/test_dataframe.py +++ b/tests/_plugins/ui/_impl/dataframes/test_dataframe.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import Any from unittest.mock import Mock, patch @@ -12,8 +13,13 @@ GetColumnValuesArgs, GetColumnValuesResponse, ) -from marimo._plugins.ui._impl.table import SearchTableArgs, TableSearchError +from marimo._plugins.ui._impl.table import ( + DownloadAsArgs, + SearchTableArgs, + TableSearchError, +) from marimo._runtime.functions import EmptyArgs +from marimo._utils.data_uri import from_data_uri from marimo._utils.platform import is_windows from tests._data.mocks import create_dataframes @@ -265,6 +271,129 @@ def test_dataframe_with_limit(df: Any) -> None: ) assert search_result.total_rows == 100 + @staticmethod + @pytest.mark.skipif( + not HAS_DEPS, reason="optional dependencies not installed" + ) + def test_dataframe_show_download() -> None: + # default behavior + df = pd.DataFrame({"A": [1, 2, 3], "B": ["a", "b", "c"]}) + subject = ui.dataframe(df) + assert subject._component_args["show-download"] is True + + # show_download=True + subject = ui.dataframe(df, show_download=True) + assert subject._component_args["show-download"] is True + + # show_download=False + subject = ui.dataframe(df, show_download=False) + assert subject._component_args["show-download"] is False + + @staticmethod + @pytest.mark.skipif( + not HAS_DEPS, reason="optional dependencies not installed" + ) + @pytest.mark.parametrize("format_type", ["csv", "json", "parquet"]) + def test_dataframe_download_formats(format_type) -> None: + df = pd.DataFrame( + { + "cities": ["Newark", "New York", "Los Angeles"], + "population": [311549, 8336817, 3898747], + } + ) + subject = ui.dataframe(df) 
+ + # no transformations + download_url = subject._download_as(DownloadAsArgs(format=format_type)) + assert download_url.startswith("data:") + + data_bytes = from_data_uri(download_url)[1] + assert len(data_bytes) > 0 + + @staticmethod + @pytest.mark.skipif( + not HAS_DEPS, reason="optional dependencies not installed" + ) + def test_dataframe_download_with_transformations() -> None: + df = pd.DataFrame( + { + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + "city": ["New York", "Newark", "Los Angeles"], + } + ) + subject = ui.dataframe(df) + + # Apply some transformations (would be done through the UI) + subject._value = df[df["age"] > 27] + + # download with transformations applied + download_url = subject._download_as(DownloadAsArgs(format="json")) + data_bytes = from_data_uri(download_url)[1] + + json_data = json.loads(data_bytes.decode("utf-8")) + + assert len(json_data) == 2 + names = [row["name"] for row in json_data] + assert "Bob" in names + assert "Charlie" in names + assert "Alice" not in names + + @staticmethod + @pytest.mark.skipif( + not HAS_DEPS, reason="optional dependencies not installed" + ) + def test_dataframe_download_empty() -> None: + df = pd.DataFrame({"A": [], "B": []}) + subject = ui.dataframe(df) + + download_url = subject._download_as(DownloadAsArgs(format="csv")) + data_bytes = from_data_uri(download_url)[1] + + csv_content = data_bytes.decode("utf-8") + assert "A,B" in csv_content or "A" in csv_content + + @staticmethod + @pytest.mark.skipif( + not HAS_DEPS, reason="optional dependencies not installed" + ) + def test_dataframe_download_unsupported_format() -> None: + df = pd.DataFrame({"A": [1, 2, 3]}) + subject = ui.dataframe(df) + + # unsupported format + with pytest.raises(ValueError) as exc_info: + subject._download_as(DownloadAsArgs(format="xml")) + + assert "format must be one of 'csv', 'json', or 'parquet'" in str( + exc_info.value + ) + + @staticmethod + @pytest.mark.skipif( + not HAS_DEPS, reason="optional 
dependencies not installed" + ) + @pytest.mark.parametrize( + "df", + create_dataframes( + {"A": [1, 2, 3], "B": ["x", "y", "z"]}, + exclude=["pyarrow", "duckdb", "lazy-polars"], + ), + ) + def test_dataframe_download_different_backends(df) -> None: + subject = ui.dataframe(df) + + # Test that download works with different dataframe backends + for format_type in ["csv", "json", "parquet"]: + try: + download_url = subject._download_as( + DownloadAsArgs(format=format_type) + ) + assert download_url.startswith("data:") + except Exception as e: + # Some backends might not support all formats + pytest.skip(f"Backend doesn't support {format_type}: {e}") + @staticmethod @pytest.mark.parametrize( "df",