Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions frontend/src/components/data-table/schemas.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/* Copyright 2024 Marimo. All rights reserved. */

import z from "zod";
import { rpc } from "@/plugins/core/rpc";

export type DownloadAsArgs = (req: {
format: "csv" | "json" | "parquet";
}) => Promise<string>;

export const DownloadAsSchema = rpc
.input(
z.object({
format: z.enum(["csv", "json", "parquet"]),
}),
)
.output(z.string());
10 changes: 6 additions & 4 deletions frontend/src/plugins/impl/DataTablePlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import {
import { usePanelOwnership } from "@/components/data-table/hooks/use-panel-ownership";
import { LoadingTable } from "@/components/data-table/loading-table";
import { RowViewerPanel } from "@/components/data-table/row-viewer-panel/row-viewer";
import {
type DownloadAsArgs,
DownloadAsSchema,
} from "@/components/data-table/schemas";
import {
type BinValues,
type ColumnHeaderStats,
Expand Down Expand Up @@ -191,7 +195,7 @@ interface Data<T> {

// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
type DataTableFunctions = {
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
download_as: DownloadAsArgs;
get_column_summaries: <T>(
opts: ColumnSummariesArgs,
) => Promise<ColumnSummaries<T>>;
Expand Down Expand Up @@ -265,9 +269,7 @@ export const DataTablePlugin = createPlugin<S>("marimo-table")
}),
)
.withFunctions<DataTableFunctions>({
download_as: rpc
.input(z.object({ format: z.enum(["csv", "json", "parquet"]) }))
.output(z.string()),
download_as: DownloadAsSchema,
get_column_summaries: rpc
.input(z.object({ precompute: z.boolean() }))
.output(
Expand Down
16 changes: 7 additions & 9 deletions frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ import { isEqual } from "lodash-es";
import { Code2Icon, DatabaseIcon, FunctionSquareIcon } from "lucide-react";
import { type JSX, memo, useEffect, useRef, useState } from "react";
import { z } from "zod";
import {
type DownloadAsArgs,
DownloadAsSchema,
} from "@/components/data-table/schemas";
import type { FieldTypesWithExternalType } from "@/components/data-table/types";
import { ReadonlyCode } from "@/components/editor/code/readonly-python-code";
import { Spinner } from "@/components/icons/spinner";
Expand Down Expand Up @@ -68,7 +72,7 @@ type PluginFunctions = {
data: TableData<T>;
total_rows: number;
}>;
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
download_as: DownloadAsArgs;
};

// Value is selection, but it is not currently exposed to the user
Expand Down Expand Up @@ -127,13 +131,7 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe")
total_rows: z.number(),
}),
),
download_as: rpc
.input(
z.object({
format: z.enum(["csv", "json", "parquet"]),
}),
)
.output(z.string()),
download_as: DownloadAsSchema,
})
.renderer((props) => (
<TableProviders>
Expand All @@ -152,7 +150,7 @@ interface DataTableProps extends Data, PluginFunctions {
setValue: (value: S) => void;
host: HTMLElement;
showDownload: boolean;
download_as: (req: { format: "csv" | "json" | "parquet" }) => Promise<string>;
download_as: DownloadAsArgs;
}

const EMPTY: Transformations = {
Expand Down
15 changes: 2 additions & 13 deletions marimo/_plugins/ui/_impl/dataframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
Union,
)

import marimo._output.data.data as mo_data
from marimo._output.rich_help import mddoc
from marimo._plugins.ui._core.ui_element import UIElement
from marimo._plugins.ui._impl.dataframes.transforms.apply import (
Expand All @@ -38,6 +37,7 @@
from marimo._plugins.ui._impl.tables.utils import (
get_table_manager,
)
from marimo._plugins.ui._impl.utils.dataframe import download_as
from marimo._plugins.validators import (
validate_no_integer_columns,
validate_page_size,
Expand Down Expand Up @@ -298,18 +298,7 @@ def _download_as(self, args: DownloadAsArgs) -> str:

# Get the table manager for the transformed data
manager = self._get_cached_table_manager(df, self._limit)

ext = args.format
if ext == "csv":
return mo_data.csv(manager.to_csv()).url
elif ext == "json":
return mo_data.json(manager.to_json()).url
elif ext == "parquet":
return mo_data.parquet(manager.to_parquet()).url
else:
raise ValueError(
"format must be one of 'csv', 'json', or 'parquet'."
)
return download_as(manager, args.format)

def _apply_filters_query_sort(
self,
Expand Down
20 changes: 8 additions & 12 deletions marimo/_plugins/ui/_impl/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@
TableManager,
)
from marimo._plugins.ui._impl.tables.utils import get_table_manager
from marimo._plugins.ui._impl.utils.dataframe import ListOrTuple, TableData
from marimo._plugins.ui._impl.utils.dataframe import (
ListOrTuple,
TableData,
download_as,
)
from marimo._plugins.validators import (
validate_no_integer_columns,
validate_page_size,
Expand Down Expand Up @@ -806,17 +810,9 @@ def _download_as(self, args: DownloadAsArgs) -> str:

# Remove the selection column before downloading
if isinstance(manager_candidate, TableManager):
manager = manager_candidate.drop_columns([INDEX_COLUMN_NAME])

ext = args.format
if ext == "csv":
return mo_data.csv(manager.to_csv()).url
elif ext == "json":
return mo_data.json(manager.to_json()).url
elif ext == "parquet":
return mo_data.parquet(manager.to_parquet()).url
else:
raise ValueError("format must be one of 'csv' or 'json'.")
return download_as(
manager_candidate, args.format, drop_marimo_index=True
)
else:
raise NotImplementedError(
"Download is not supported for this table format."
Expand Down
40 changes: 36 additions & 4 deletions marimo/_plugins/ui/_impl/utils/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from typing import (
TypeVar,
Union,
)
from typing import Any, TypeVar, Union

from narwhals.typing import IntoDataFrame

from marimo import _loggers
from marimo._output.data import data as mo_data
from marimo._output.mime import MIME
from marimo._plugins.core.web_component import JSONType
from marimo._plugins.ui._impl.tables.selection import INDEX_COLUMN_NAME
from marimo._plugins.ui._impl.tables.table_manager import TableManager

LOGGER = _loggers.marimo_logger()

Expand All @@ -26,3 +26,35 @@
dict[str, ListOrTuple[JSONType]],
IntoDataFrame,
]


def download_as(
manager: TableManager[Any], ext: str, drop_marimo_index: bool = False
) -> str:
"""Download the table data in the specified format.

Args:
manager (TableManager[Any]): The table manager to download.
ext (str): The format to download the table data in.
drop_marimo_index (bool, optional): Whether to drop the marimo selection column.
Defaults to False.

Raises:
ValueError: If unrecognized format.
NotImplementedError: If the table format is not supported.

Returns:
str: The URL to download the table data.
"""
if drop_marimo_index:
# Remove the selection column if exists
manager = manager.drop_columns([INDEX_COLUMN_NAME])

if ext == "csv":
return mo_data.csv(manager.to_csv()).url
elif ext == "json":
return mo_data.json(manager.to_json()).url
elif ext == "parquet":
return mo_data.parquet(manager.to_parquet()).url
else:
raise ValueError("format must be one of 'csv', 'json', or 'parquet'.")
Loading