Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fb86632
Add multi-column sorting with stack-based behavior
lucharo Sep 5, 2025
fdbad05
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2025
1c8e68c
Remove Claude settings file
lucharo Sep 5, 2025
03f13ed
Fix multi-column sorting RPC serialization
lucharo Sep 7, 2025
00cac12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2025
4038124
Add multi-column sorting with stack-based priority
lucharo Oct 3, 2025
819c912
Add tests for multi-column sorting logic
lucharo Oct 3, 2025
b567353
Merge remote-tracking branch 'origin/main' into multi-column-sort
lucharo Oct 3, 2025
2f95bd1
Merge remote-tracking branch 'upstream/main' into multi-column-sort
lucharo Oct 3, 2025
ac9c7f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2025
9022d36
Use list[SortArgs] instead of list[tuple] for cleaner API
lucharo Oct 3, 2025
cc78a02
Simplify sort_values logic in DefaultTableManager
lucharo Oct 3, 2025
a88a12d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2025
b42c4fc
Refactor: reduce duplication in sort UI and data conversion
lucharo Oct 3, 2025
12cb0d2
Address Light2Dark's PR review feedback
lucharo Oct 4, 2025
6a6f2e1
Fix linting issues
lucharo Oct 4, 2025
4dff80d
Fix circular import with TYPE_CHECKING
lucharo Oct 4, 2025
7deb726
Add missing SortDirection type import
lucharo Oct 4, 2025
b3840f5
Remove unnecessary tuple conversion in dataframe.py
lucharo Oct 4, 2025
f9dfab3
Update all tests to use SortArgs instead of tuples
lucharo Oct 5, 2025
573f6a6
Fix mixed-type and None sorting in DefaultTableManager
lucharo Oct 5, 2025
afde587
Address PR feedback: filter invalid sort columns and remove artifacts
lucharo Oct 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion frontend/src/components/data-table/data-table.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,33 @@ const DataTableInternal = <TData,>({
manualPagination: manualPagination,
getPaginationRowModel: getPaginationRowModel(),
// sorting
...(setSorting ? { onSortingChange: setSorting } : {}),
...(setSorting
? {
onSortingChange: (updaterOrValue) => {
// Custom sorting logic - stack behavior
const newSorting =
typeof updaterOrValue === "function"
? updaterOrValue(sorting || [])
: updaterOrValue;

// Implement stack behavior: if column already exists, remove it and add to front
const customSorting = newSorting.reduce(
(acc: typeof newSorting, sort) => {
// Remove any existing sorts for this column
const filtered = acc.filter((s) => s.id !== sort.id);
// Add the new sort to the front (most recent)
return [sort, ...filtered];
},
[],
);

setSorting(customSorting);
},
}
: {}),
manualSorting: manualSorting,
enableMultiSort: true,
isMultiSortEvent: () => true, // Always enable multi-sort (no shift key required)
getSortedRowModel: getSortedRowModel(),
// filtering
manualFiltering: true,
Expand Down
86 changes: 77 additions & 9 deletions frontend/src/components/data-table/header-items.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* Copyright 2024 Marimo. All rights reserved. */

import { PinLeftIcon, PinRightIcon } from "@radix-ui/react-icons";
import type { Column } from "@tanstack/react-table";
import type { Column, SortingState, Table } from "@tanstack/react-table";
import {
AlignJustifyIcon,
ArrowDownWideNarrowIcon,
Expand All @@ -14,6 +14,7 @@
ListFilterPlusIcon,
PinOffIcon,
WrapTextIcon,
XIcon,
} from "lucide-react";
import {
DropdownMenuItem,
Expand Down Expand Up @@ -159,25 +160,92 @@
const AscIcon = ArrowUpNarrowWideIcon;
const DescIcon = ArrowDownWideNarrowIcon;

export function renderSorts<TData, TValue>(column: Column<TData, TValue>) {
export function renderSorts<TData, TValue>(
column: Column<TData, TValue>,
table?: Table<TData>,
) {
if (!column.getCanSort()) {
return null;
}

// Try to get table from column (TanStack Table should provide this)
const tableFromColumn = (column as any).table || table;

Check failure on line 172 in frontend/src/components/data-table/header-items.tsx

View workflow job for this annotation

GitHub Actions / 🧹 Lint frontend

Unexpected any. Specify a different type

// If table is available (either passed or from column), use full multi-column sort functionality
if (tableFromColumn) {
const sortingState: SortingState = tableFromColumn.getState().sorting;
const currentSort = sortingState.find((s) => s.id === column.id);
const sortIndex = currentSort
? sortingState.indexOf(currentSort) + 1
: null;

return (
<>
<DropdownMenuItem onClick={() => column.toggleSorting(false, true)}>
<AscIcon className="mo-dropdown-icon" />
Sort Ascending
{sortIndex && currentSort && !currentSort.desc && (
<span className="ml-auto text-xs bg-blue-100 text-blue-800 px-1 rounded">
{sortIndex}
</span>
)}
</DropdownMenuItem>
<DropdownMenuItem onClick={() => column.toggleSorting(true, true)}>
<DescIcon className="mo-dropdown-icon" />
Sort Descending
{sortIndex && currentSort && currentSort.desc && (
<span className="ml-auto text-xs bg-blue-100 text-blue-800 px-1 rounded">
{sortIndex}
</span>
)}
</DropdownMenuItem>
{currentSort && (
<DropdownMenuItem onClick={() => column.clearSorting()}>
<XIcon className="mo-dropdown-icon" />
Remove Sort
<span className="ml-auto text-xs bg-red-100 text-red-800 px-1 rounded">
{sortIndex}
</span>
</DropdownMenuItem>
)}
{sortingState.length > 0 && (
<DropdownMenuItem onClick={() => tableFromColumn.resetSorting()}>
<FilterX className="mo-dropdown-icon" />
Clear All Sorts
</DropdownMenuItem>
)}
<DropdownMenuSeparator />
</>
);
}

// Fallback to simple sorting if table not provided
const isSorted = column.getIsSorted();

return (
<>
<DropdownMenuItem onClick={() => column.toggleSorting(false)}>
<DropdownMenuItem onClick={() => column.toggleSorting(false, true)}>
<AscIcon className="mo-dropdown-icon" />
Asc
Sort Ascending
{isSorted === "asc" && (
<span className="ml-auto text-xs bg-blue-100 text-blue-800 px-1 rounded">
</span>
)}
</DropdownMenuItem>
<DropdownMenuItem onClick={() => column.toggleSorting(true)}>
<DropdownMenuItem onClick={() => column.toggleSorting(true, true)}>
<DescIcon className="mo-dropdown-icon" />
Desc
Sort Descending
{isSorted === "desc" && (
<span className="ml-auto text-xs bg-blue-100 text-blue-800 px-1 rounded">
</span>
)}
</DropdownMenuItem>
{column.getIsSorted() && (
{isSorted && (
<DropdownMenuItem onClick={() => column.clearSorting()}>
<ChevronsUpDown className="mo-dropdown-icon" />
Clear sort
<XIcon className="mo-dropdown-icon" />
Remove Sort
</DropdownMenuItem>
)}
<DropdownMenuSeparator />
Expand Down
21 changes: 10 additions & 11 deletions frontend/src/plugins/impl/DataTablePlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ type DataTableFunctions = {
) => Promise<ColumnSummaries<T>>;
search: <T>(req: {
sort?: {
by: string;
descending: boolean;
by: string[];
descending: boolean[];
};
query?: string;
filters?: ConditionType[];
Expand Down Expand Up @@ -282,7 +282,10 @@ export const DataTablePlugin = createPlugin<S>("marimo-table")
.input(
z.object({
sort: z
.object({ by: z.string(), descending: z.boolean() })
.object({
by: z.array(z.string()),
descending: z.array(z.boolean()),
})
.optional(),
query: z.string().optional(),
filters: z.array(ConditionSchema).optional(),
Expand Down Expand Up @@ -476,17 +479,13 @@ export const LoadingDataTableComponent = memo(
!props.lazy &&
!pageSizeChanged;

if (sorting.length > 1) {
Logger.warn("Multiple sort columns are not supported");
}

// If we have sort/search/filter, use the search function
const searchResultsPromise = search<T>({
sort:
sorting.length > 0
? {
by: sorting[0].id,
descending: sorting[0].desc,
by: sorting.map((column) => column.id),
descending: sorting.map((column) => column.desc),
}
: undefined,
query: searchQuery,
Expand Down Expand Up @@ -541,8 +540,8 @@ export const LoadingDataTableComponent = memo(
sort:
sorting.length > 0
? {
by: sorting[0].id,
descending: sorting[0].desc,
by: sorting.map((column) => column.id),
descending: sorting.map((column) => column.desc),
}
: undefined,
query: searchQuery,
Expand Down
9 changes: 6 additions & 3 deletions frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ type PluginFunctions = {
}>;
search: <T>(req: {
sort?: {
by: string;
descending: boolean;
by: string[];
descending: boolean[];
};
query?: string;
filters?: ConditionType[];
Expand Down Expand Up @@ -110,7 +110,10 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe")
.input(
z.object({
sort: z
.object({ by: z.string(), descending: z.boolean() })
.object({
by: z.array(z.string()),
descending: z.array(z.boolean()),
})
.optional(),
query: z.string().optional(),
filters: z.array(ConditionSchema).optional(),
Expand Down
10 changes: 8 additions & 2 deletions marimo/_plugins/ui/_impl/dataframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,14 @@ def _apply_filters_query_sort(
if query:
result = result.search(query)

if sort and sort.by in result.get_column_names():
result = result.sort_values(sort.by, sort.descending)
if sort:
# Convert tuples to lists to match the sort_values method signature
by_list = list(sort.by)
descending_list = list(sort.descending)
# Check that all columns exist
existing_columns = set(result.get_column_names())
if all(col in existing_columns for col in by_list):
result = result.sort_values(by_list, descending_list)

return result

Expand Down
22 changes: 14 additions & 8 deletions marimo/_plugins/ui/_impl/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ class ColumnSummaries:
MaxColumnsType = Union[int, None, MaxColumnsNotProvided]


@dataclass(frozen=True)
class SortArgs:
by: tuple[ColumnName, ...]
descending: tuple[bool, ...]


@dataclass(frozen=True)
class SearchTableArgs:
page_size: int
Expand All @@ -135,12 +141,6 @@ class SearchTableResponse:
cell_styles: Optional[CellStyles] = None


@dataclass(frozen=True)
class SortArgs:
by: ColumnName
descending: bool


@dataclass
class GetRowIdsResponse:
row_ids: list[int]
Expand Down Expand Up @@ -1100,8 +1100,14 @@ def _apply_filters_query_sort(
if query:
result = result.search(query)

if sort and sort.by in result.get_column_names():
result = result.sort_values(sort.by, sort.descending)
if sort:
# Convert tuples to lists to match the sort_values method signature
by_list = list(sort.by)
descending_list = list(sort.descending)
# Check that all columns exist
existing_columns = set(result.get_column_names())
if all(col in existing_columns for col in by_list):
result = result.sort_values(by_list, descending_list)

return result

Expand Down
Loading
Loading