Skip to content

Commit a019e13

Browse files
authored
feat: Add "download" option to mo.ui.dataframe (#6492)
## 📝 Summary Fixes #5923. This PR adds a "Download" button to `mo.ui.dataframe`, allowing users to export data as CSV, JSON, or Parquet. It also introduces a `show_download` parameter to control its visibility. <img width="774" height="399" alt="image" src="https://github.com/user-attachments/assets/84d27b72-5a9a-4791-bef9-ddab6ac7c1fa" /> ## 📋 Checklist - [x] I have read the [contributor guidelines](https://github.com/marimo-team/marimo/blob/main/CONTRIBUTING.md). - [x] For large changes, or changes that affect the public API: this change was discussed or approved through an issue, on [Discord](https://marimo.io/discord?ref=pr), or the community [discussions](https://github.com/marimo-team/marimo/discussions) (Please provide a link if applicable). - [x] I have added tests for the changes made. - [x] I have run the code and verified that it works as expected. Disclosure: I used claude to help writing the tests.
1 parent df8d8f0 commit a019e13

File tree

4 files changed

+196
-6
lines changed

4 files changed

+196
-6
lines changed

frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ interface Data {
3838
label?: string | null;
3939
columns: ColumnDataTypes;
4040
pageSize: number;
41+
showDownload: boolean;
4142
}
4243

4344
// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
@@ -67,6 +68,7 @@ type PluginFunctions = {
6768
data: TableData<T>;
6869
total_rows: number;
6970
}>;
71+
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
7072
};
7173

7274
// Value is selection, but it is not currently exposed to the user
@@ -77,13 +79,14 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe")
7779
z.object({
7880
label: z.string().nullish(),
7981
pageSize: z.number().default(5),
82+
showDownload: z.boolean().default(true),
8083
columns: z
8184
.array(z.tuple([z.string().or(z.number()), z.string(), z.string()]))
8285
.transform((value) => {
8386
const map = new Map<ColumnId, string>();
84-
value.forEach(([key, dataType]) =>
85-
map.set(key as ColumnId, dataType as DataType),
86-
);
87+
value.forEach(([key, dataType]) => {
88+
map.set(key as ColumnId, dataType as DataType);
89+
});
8790
return map;
8891
}),
8992
}),
@@ -124,6 +127,13 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe")
124127
total_rows: z.number(),
125128
}),
126129
),
130+
download_as: rpc
131+
.input(
132+
z.object({
133+
format: z.enum(["csv", "json", "parquet"]),
134+
}),
135+
)
136+
.output(z.string()),
127137
})
128138
.renderer((props) => (
129139
<TableProviders>
@@ -141,6 +151,8 @@ interface DataTableProps extends Data, PluginFunctions {
141151
value: S;
142152
setValue: (value: S) => void;
143153
host: HTMLElement;
154+
showDownload: boolean;
155+
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
144156
}
145157

146158
const EMPTY: Transformations = {
@@ -151,11 +163,13 @@ export const DataFrameComponent = memo(
151163
({
152164
columns,
153165
pageSize,
166+
showDownload,
154167
value,
155168
setValue,
156169
get_dataframe,
157170
get_column_values,
158171
search,
172+
download_as,
159173
host,
160174
}: DataTableProps): JSX.Element => {
161175
const { data, error, isPending } = useAsyncData(
@@ -270,8 +284,8 @@ export const DataFrameComponent = memo(
270284
pagination={true}
271285
fieldTypes={field_types}
272286
rowHeaders={row_headers || Arrays.EMPTY}
273-
showDownload={false}
274-
download_as={Functions.THROW}
287+
showDownload={showDownload}
288+
download_as={download_as}
275289
enableSearch={false}
276290
showFilters={false}
277291
search={search}

frontend/src/stories/dataframe.stories.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ export const DataFrame: StoryObj = {
4040
get_dataframe={() => Promise.reject(new Error("not implemented"))}
4141
search={Functions.THROW}
4242
host={document.body}
43+
showDownload={false}
44+
download_as={async () => ""}
4345
/>
4446
);
4547
},

marimo/_plugins/ui/_impl/dataframes/dataframe.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Union,
1414
)
1515

16+
import marimo._output.data.data as mo_data
1617
from marimo._output.rich_help import mddoc
1718
from marimo._plugins.ui._core.ui_element import UIElement
1819
from marimo._plugins.ui._impl.dataframes.transforms.apply import (
@@ -24,6 +25,7 @@
2425
Transformations,
2526
)
2627
from marimo._plugins.ui._impl.table import (
28+
DownloadAsArgs,
2729
SearchTableArgs,
2830
SearchTableResponse,
2931
SortArgs,
@@ -101,6 +103,8 @@ class dataframe(UIElement[dict[str, Any], DataFrameType]):
101103
limit (Optional[int], optional): The number of items to load into memory, in case
102104
the data is remote and lazily fetched. This is likely true for SQL-backed
103105
dataframes via Ibis.
106+
show_download (bool, optional): Whether to show the download button.
107+
Defaults to True.
104108
on_change (Optional[Callable[[DataFrameType], None]], optional): Optional callback
105109
to run when this element's value changes.
106110
"""
@@ -113,6 +117,7 @@ def __init__(
113117
on_change: Optional[Callable[[DataFrameType], None]] = None,
114118
page_size: Optional[int] = 5,
115119
limit: Optional[int] = None,
120+
show_download: bool = True,
116121
) -> None:
117122
validate_no_integer_columns(df)
118123
# This will raise an error if the dataframe type is not supported.
@@ -144,6 +149,7 @@ def __init__(
144149
self._error: Optional[str] = None
145150
self._last_transforms = Transformations([])
146151
self._page_size = page_size or 5 # Default to 5 rows (.head())
152+
self._show_download = show_download
147153
validate_page_size(self._page_size)
148154

149155
super().__init__(
@@ -158,6 +164,7 @@ def __init__(
158164
"dataframe-name": dataframe_name,
159165
"total": self._manager.get_num_rows(force=False),
160166
"page-size": page_size,
167+
"show-download": show_download,
161168
},
162169
functions=(
163170
Function(
@@ -175,6 +182,11 @@ def __init__(
175182
arg_cls=SearchTableArgs,
176183
function=self._search,
177184
),
185+
Function(
186+
name="download_as",
187+
arg_cls=DownloadAsArgs,
188+
function=self._download_as,
189+
),
178190
),
179191
)
180192

@@ -266,6 +278,39 @@ def _search(self, args: SearchTableArgs) -> SearchTableResponse:
266278
total_rows=result.get_num_rows(force=True) or 0,
267279
)
268280

281+
def _download_as(self, args: DownloadAsArgs) -> str:
282+
"""Download the transformed dataframe in the specified format.
283+
284+
Downloads the dataframe with all current transformations applied.
285+
286+
Args:
287+
args (DownloadAsArgs): Arguments specifying the download format.
288+
format must be one of 'csv', 'json', or 'parquet'.
289+
290+
Returns:
291+
str: URL to download the data file.
292+
293+
Raises:
294+
ValueError: If format is not supported.
295+
"""
296+
# Get transformed dataframe
297+
df = self._value
298+
299+
# Get the table manager for the transformed data
300+
manager = self._get_cached_table_manager(df, self._limit)
301+
302+
ext = args.format
303+
if ext == "csv":
304+
return mo_data.csv(manager.to_csv()).url
305+
elif ext == "json":
306+
return mo_data.json(manager.to_json()).url
307+
elif ext == "parquet":
308+
return mo_data.parquet(manager.to_parquet()).url
309+
else:
310+
raise ValueError(
311+
"format must be one of 'csv', 'json', or 'parquet'."
312+
)
313+
269314
def _apply_filters_query_sort(
270315
self,
271316
query: Optional[str],

tests/_plugins/ui/_impl/dataframes/test_dataframe.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import json
34
from typing import Any
45
from unittest.mock import Mock, patch
56

@@ -12,8 +13,13 @@
1213
GetColumnValuesArgs,
1314
GetColumnValuesResponse,
1415
)
15-
from marimo._plugins.ui._impl.table import SearchTableArgs, TableSearchError
16+
from marimo._plugins.ui._impl.table import (
17+
DownloadAsArgs,
18+
SearchTableArgs,
19+
TableSearchError,
20+
)
1621
from marimo._runtime.functions import EmptyArgs
22+
from marimo._utils.data_uri import from_data_uri
1723
from marimo._utils.platform import is_windows
1824
from tests._data.mocks import create_dataframes
1925

@@ -265,6 +271,129 @@ def test_dataframe_with_limit(df: Any) -> None:
265271
)
266272
assert search_result.total_rows == 100
267273

274+
@staticmethod
275+
@pytest.mark.skipif(
276+
not HAS_DEPS, reason="optional dependencies not installed"
277+
)
278+
def test_dataframe_show_download() -> None:
279+
# default behavior
280+
df = pd.DataFrame({"A": [1, 2, 3], "B": ["a", "b", "c"]})
281+
subject = ui.dataframe(df)
282+
assert subject._component_args["show-download"] is True
283+
284+
# show_download=True
285+
subject = ui.dataframe(df, show_download=True)
286+
assert subject._component_args["show-download"] is True
287+
288+
# show_download=False
289+
subject = ui.dataframe(df, show_download=False)
290+
assert subject._component_args["show-download"] is False
291+
292+
@staticmethod
293+
@pytest.mark.skipif(
294+
not HAS_DEPS, reason="optional dependencies not installed"
295+
)
296+
@pytest.mark.parametrize("format_type", ["csv", "json", "parquet"])
297+
def test_dataframe_download_formats(format_type) -> None:
298+
df = pd.DataFrame(
299+
{
300+
"cities": ["Newark", "New York", "Los Angeles"],
301+
"population": [311549, 8336817, 3898747],
302+
}
303+
)
304+
subject = ui.dataframe(df)
305+
306+
# no transformations
307+
download_url = subject._download_as(DownloadAsArgs(format=format_type))
308+
assert download_url.startswith("data:")
309+
310+
data_bytes = from_data_uri(download_url)[1]
311+
assert len(data_bytes) > 0
312+
313+
@staticmethod
314+
@pytest.mark.skipif(
315+
not HAS_DEPS, reason="optional dependencies not installed"
316+
)
317+
def test_dataframe_download_with_transformations() -> None:
318+
df = pd.DataFrame(
319+
{
320+
"name": ["Alice", "Bob", "Charlie"],
321+
"age": [25, 30, 35],
322+
"city": ["New York", "Newark", "Los Angeles"],
323+
}
324+
)
325+
subject = ui.dataframe(df)
326+
327+
# Apply some transformations (would be done through the UI)
328+
subject._value = df[df["age"] > 27]
329+
330+
# download with transformations applied
331+
download_url = subject._download_as(DownloadAsArgs(format="json"))
332+
data_bytes = from_data_uri(download_url)[1]
333+
334+
json_data = json.loads(data_bytes.decode("utf-8"))
335+
336+
assert len(json_data) == 2
337+
names = [row["name"] for row in json_data]
338+
assert "Bob" in names
339+
assert "Charlie" in names
340+
assert "Alice" not in names
341+
342+
@staticmethod
343+
@pytest.mark.skipif(
344+
not HAS_DEPS, reason="optional dependencies not installed"
345+
)
346+
def test_dataframe_download_empty() -> None:
347+
df = pd.DataFrame({"A": [], "B": []})
348+
subject = ui.dataframe(df)
349+
350+
download_url = subject._download_as(DownloadAsArgs(format="csv"))
351+
data_bytes = from_data_uri(download_url)[1]
352+
353+
csv_content = data_bytes.decode("utf-8")
354+
assert "A,B" in csv_content or "A" in csv_content
355+
356+
@staticmethod
357+
@pytest.mark.skipif(
358+
not HAS_DEPS, reason="optional dependencies not installed"
359+
)
360+
def test_dataframe_download_unsupported_format() -> None:
361+
df = pd.DataFrame({"A": [1, 2, 3]})
362+
subject = ui.dataframe(df)
363+
364+
# unsupported format
365+
with pytest.raises(ValueError) as exc_info:
366+
subject._download_as(DownloadAsArgs(format="xml"))
367+
368+
assert "format must be one of 'csv', 'json', or 'parquet'" in str(
369+
exc_info.value
370+
)
371+
372+
@staticmethod
373+
@pytest.mark.skipif(
374+
not HAS_DEPS, reason="optional dependencies not installed"
375+
)
376+
@pytest.mark.parametrize(
377+
"df",
378+
create_dataframes(
379+
{"A": [1, 2, 3], "B": ["x", "y", "z"]},
380+
exclude=["pyarrow", "duckdb", "lazy-polars"],
381+
),
382+
)
383+
def test_dataframe_download_different_backends(df) -> None:
384+
subject = ui.dataframe(df)
385+
386+
# Test that download works with different dataframe backends
387+
for format_type in ["csv", "json", "parquet"]:
388+
try:
389+
download_url = subject._download_as(
390+
DownloadAsArgs(format=format_type)
391+
)
392+
assert download_url.startswith("data:")
393+
except Exception as e:
394+
# Some backends might not support all formats
395+
pytest.skip(f"Backend doesn't support {format_type}: {e}")
396+
268397
@staticmethod
269398
@pytest.mark.parametrize(
270399
"df",

0 commit comments

Comments
 (0)