Skip to content

Commit 666fc69

Browse files
authored
share types and function between table and dataframe for download (#6497)
## 📝 Summary <!-- Provide a concise summary of what this pull request is addressing. If this PR fixes any issues, list them here by number (e.g., Fixes #123). --> This just helps with future maintainability by reusing types and functions. ## 🔍 Description of Changes <!-- Detail the specific changes made in this pull request. Explain the problem addressed and how it was resolved. If applicable, provide before and after comparisons, screenshots, or any relevant details to help reviewers understand the changes easily. --> ## 📋 Checklist - [x] I have read the [contributor guidelines](https://github.com/marimo-team/marimo/blob/main/CONTRIBUTING.md). - [ ] For large changes, or changes that affect the public API: this change was discussed or approved through an issue, on [Discord](https://marimo.io/discord?ref=pr), or the community [discussions](https://github.com/marimo-team/marimo/discussions) (Please provide a link if applicable). - [ ] I have added tests for the changes made. - [x] I have run the code and verified that it works as expected.
1 parent 570fd0c commit 666fc69

6 files changed

Lines changed: 75 additions & 42 deletions

File tree

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/* Copyright 2024 Marimo. All rights reserved. */
2+
3+
import z from "zod";
4+
import { rpc } from "@/plugins/core/rpc";
5+
6+
export type DownloadAsArgs = (req: {
7+
format: "csv" | "json" | "parquet";
8+
}) => Promise<string>;
9+
10+
export const DownloadAsSchema = rpc
11+
.input(
12+
z.object({
13+
format: z.enum(["csv", "json", "parquet"]),
14+
}),
15+
)
16+
.output(z.string());

frontend/src/plugins/impl/DataTablePlugin.tsx

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ import {
3636
import { usePanelOwnership } from "@/components/data-table/hooks/use-panel-ownership";
3737
import { LoadingTable } from "@/components/data-table/loading-table";
3838
import { RowViewerPanel } from "@/components/data-table/row-viewer-panel/row-viewer";
39+
import {
40+
type DownloadAsArgs,
41+
DownloadAsSchema,
42+
} from "@/components/data-table/schemas";
3943
import {
4044
type BinValues,
4145
type ColumnHeaderStats,
@@ -191,7 +195,7 @@ interface Data<T> {
191195

192196
// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
193197
type DataTableFunctions = {
194-
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
198+
download_as: DownloadAsArgs;
195199
get_column_summaries: <T>(
196200
opts: ColumnSummariesArgs,
197201
) => Promise<ColumnSummaries<T>>;
@@ -265,9 +269,7 @@ export const DataTablePlugin = createPlugin<S>("marimo-table")
265269
}),
266270
)
267271
.withFunctions<DataTableFunctions>({
268-
download_as: rpc
269-
.input(z.object({ format: z.enum(["csv", "json", "parquet"]) }))
270-
.output(z.string()),
272+
download_as: DownloadAsSchema,
271273
get_column_summaries: rpc
272274
.input(z.object({ precompute: z.boolean() }))
273275
.output(

frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ import { isEqual } from "lodash-es";
44
import { Code2Icon, DatabaseIcon, FunctionSquareIcon } from "lucide-react";
55
import { type JSX, memo, useEffect, useRef, useState } from "react";
66
import { z } from "zod";
7+
import {
8+
type DownloadAsArgs,
9+
DownloadAsSchema,
10+
} from "@/components/data-table/schemas";
711
import type { FieldTypesWithExternalType } from "@/components/data-table/types";
812
import { ReadonlyCode } from "@/components/editor/code/readonly-python-code";
913
import { Spinner } from "@/components/icons/spinner";
@@ -68,7 +72,7 @@ type PluginFunctions = {
6872
data: TableData<T>;
6973
total_rows: number;
7074
}>;
71-
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
75+
download_as: DownloadAsArgs;
7276
};
7377

7478
// Value is selection, but it is not currently exposed to the user
@@ -127,13 +131,7 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe")
127131
total_rows: z.number(),
128132
}),
129133
),
130-
download_as: rpc
131-
.input(
132-
z.object({
133-
format: z.enum(["csv", "json", "parquet"]),
134-
}),
135-
)
136-
.output(z.string()),
134+
download_as: DownloadAsSchema,
137135
})
138136
.renderer((props) => (
139137
<TableProviders>
@@ -152,7 +150,7 @@ interface DataTableProps extends Data, PluginFunctions {
152150
setValue: (value: S) => void;
153151
host: HTMLElement;
154152
showDownload: boolean;
155-
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
153+
download_as: DownloadAsArgs;
156154
}
157155

158156
const EMPTY: Transformations = {

marimo/_plugins/ui/_impl/dataframes/dataframe.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
Union,
1414
)
1515

16-
import marimo._output.data.data as mo_data
1716
from marimo._output.rich_help import mddoc
1817
from marimo._plugins.ui._core.ui_element import UIElement
1918
from marimo._plugins.ui._impl.dataframes.transforms.apply import (
@@ -38,6 +37,7 @@
3837
from marimo._plugins.ui._impl.tables.utils import (
3938
get_table_manager,
4039
)
40+
from marimo._plugins.ui._impl.utils.dataframe import download_as
4141
from marimo._plugins.validators import (
4242
validate_no_integer_columns,
4343
validate_page_size,
@@ -298,18 +298,7 @@ def _download_as(self, args: DownloadAsArgs) -> str:
298298

299299
# Get the table manager for the transformed data
300300
manager = self._get_cached_table_manager(df, self._limit)
301-
302-
ext = args.format
303-
if ext == "csv":
304-
return mo_data.csv(manager.to_csv()).url
305-
elif ext == "json":
306-
return mo_data.json(manager.to_json()).url
307-
elif ext == "parquet":
308-
return mo_data.parquet(manager.to_parquet()).url
309-
else:
310-
raise ValueError(
311-
"format must be one of 'csv', 'json', or 'parquet'."
312-
)
301+
return download_as(manager, args.format)
313302

314303
def _apply_filters_query_sort(
315304
self,

marimo/_plugins/ui/_impl/table.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@
4848
TableManager,
4949
)
5050
from marimo._plugins.ui._impl.tables.utils import get_table_manager
51-
from marimo._plugins.ui._impl.utils.dataframe import ListOrTuple, TableData
51+
from marimo._plugins.ui._impl.utils.dataframe import (
52+
ListOrTuple,
53+
TableData,
54+
download_as,
55+
)
5256
from marimo._plugins.validators import (
5357
validate_no_integer_columns,
5458
validate_page_size,
@@ -806,17 +810,9 @@ def _download_as(self, args: DownloadAsArgs) -> str:
806810

807811
# Remove the selection column before downloading
808812
if isinstance(manager_candidate, TableManager):
809-
manager = manager_candidate.drop_columns([INDEX_COLUMN_NAME])
810-
811-
ext = args.format
812-
if ext == "csv":
813-
return mo_data.csv(manager.to_csv()).url
814-
elif ext == "json":
815-
return mo_data.json(manager.to_json()).url
816-
elif ext == "parquet":
817-
return mo_data.parquet(manager.to_parquet()).url
818-
else:
819-
raise ValueError("format must be one of 'csv' or 'json'.")
813+
return download_as(
814+
manager_candidate, args.format, drop_marimo_index=True
815+
)
820816
else:
821817
raise NotImplementedError(
822818
"Download is not supported for this table format."

marimo/_plugins/ui/_impl/utils/dataframe.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Copyright 2024 Marimo. All rights reserved.
22
from __future__ import annotations
33

4-
from typing import (
5-
TypeVar,
6-
Union,
7-
)
4+
from typing import Any, TypeVar, Union
85

96
from narwhals.typing import IntoDataFrame
107

118
from marimo import _loggers
9+
from marimo._output.data import data as mo_data
1210
from marimo._output.mime import MIME
1311
from marimo._plugins.core.web_component import JSONType
12+
from marimo._plugins.ui._impl.tables.selection import INDEX_COLUMN_NAME
13+
from marimo._plugins.ui._impl.tables.table_manager import TableManager
1414

1515
LOGGER = _loggers.marimo_logger()
1616

@@ -26,3 +26,35 @@
2626
dict[str, ListOrTuple[JSONType]],
2727
IntoDataFrame,
2828
]
29+
30+
31+
def download_as(
32+
manager: TableManager[Any], ext: str, drop_marimo_index: bool = False
33+
) -> str:
34+
"""Download the table data in the specified format.
35+
36+
Args:
37+
manager (TableManager[Any]): The table manager to download.
38+
ext (str): The format to download the table data in.
39+
drop_marimo_index (bool, optional): Whether to drop the marimo selection column.
40+
Defaults to False.
41+
42+
Raises:
43+
ValueError: If unrecognized format.
44+
NotImplementedError: If the table format is not supported.
45+
46+
Returns:
47+
str: The URL to download the table data.
48+
"""
49+
if drop_marimo_index:
50+
# Remove the selection column if exists
51+
manager = manager.drop_columns([INDEX_COLUMN_NAME])
52+
53+
if ext == "csv":
54+
return mo_data.csv(manager.to_csv()).url
55+
elif ext == "json":
56+
return mo_data.json(manager.to_json()).url
57+
elif ext == "parquet":
58+
return mo_data.parquet(manager.to_parquet()).url
59+
else:
60+
raise ValueError("format must be one of 'csv', 'json', or 'parquet'.")

0 commit comments

Comments
 (0)