Skip to content

Commit 9b58a93

Browse files
lucharoclaudepre-commit-ci[bot]
authored
Add multi-column sorting with stack-based behavior (#6257)
## Summary Implements multi-column sorting for marimo tables with stack-based behavior. Clicking a column sort moves it to the highest priority (end of sort array). Clicking the same direction again removes that sort. Visual indicators show sort priority (1, 2, 3...) in dropdown menus. "Clear sort" button adapts to "Clear all sorts" when multiple columns are sorted. Backend uses `list[SortArgs]` where each `SortArgs` contains `by: ColumnName` and `descending: bool`. Frontend sends sort state as `[{by: string, descending: boolean}]`. Compatible with latest upstream table manager refactoring (IbisTableManager now extends NarwhalsTableManager). 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 42f6183 commit 9b58a93

File tree

19 files changed

+588
-181
lines changed

19 files changed

+588
-181
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/* Copyright 2024 Marimo. All rights reserved. */
2+
3+
import type { SortingState } from "@tanstack/react-table";
4+
import { describe, expect, it, vi } from "vitest";
5+
6+
describe("multi-column sorting logic", () => {
7+
// Extract the core sorting logic to test in isolation
8+
const handleSort = (
9+
columnId: string,
10+
desc: boolean,
11+
sortingState: SortingState,
12+
setSorting: (state: SortingState) => void,
13+
clearSorting: () => void,
14+
) => {
15+
const currentSort = sortingState.find((s) => s.id === columnId);
16+
17+
if (currentSort && currentSort.desc === desc) {
18+
// Clicking the same sort again - remove it
19+
clearSorting();
20+
} else {
21+
// New sort or different direction - move to end of stack
22+
const otherSorts = sortingState.filter((s) => s.id !== columnId);
23+
const newSort = { id: columnId, desc };
24+
setSorting([...otherSorts, newSort]);
25+
}
26+
};
27+
28+
it("implements stack-based sorting: moves re-clicked column to end", () => {
29+
const sortingState: SortingState = [
30+
{ id: "name", desc: false },
31+
{ id: "age", desc: false },
32+
];
33+
const setSorting = vi.fn();
34+
const clearSorting = vi.fn();
35+
36+
// Click Desc on age - should move age to end with desc=true
37+
handleSort("age", true, sortingState, setSorting, clearSorting);
38+
39+
expect(setSorting).toHaveBeenCalledWith([
40+
{ id: "name", desc: false },
41+
{ id: "age", desc: true },
42+
]);
43+
expect(clearSorting).not.toHaveBeenCalled();
44+
});
45+
46+
it("removes sort when clicking same direction twice", () => {
47+
const sortingState: SortingState = [{ id: "age", desc: false }];
48+
const setSorting = vi.fn();
49+
const clearSorting = vi.fn();
50+
51+
// Click Asc on age again - should remove the sort
52+
handleSort("age", false, sortingState, setSorting, clearSorting);
53+
54+
expect(clearSorting).toHaveBeenCalled();
55+
expect(setSorting).not.toHaveBeenCalled();
56+
});
57+
58+
it("adds new column to end of stack", () => {
59+
const sortingState: SortingState = [{ id: "name", desc: false }];
60+
const setSorting = vi.fn();
61+
const clearSorting = vi.fn();
62+
63+
// Click Asc on age - should add age to end
64+
handleSort("age", false, sortingState, setSorting, clearSorting);
65+
66+
expect(setSorting).toHaveBeenCalledWith([
67+
{ id: "name", desc: false },
68+
{ id: "age", desc: false },
69+
]);
70+
expect(clearSorting).not.toHaveBeenCalled();
71+
});
72+
73+
it("toggles sort direction when clicking opposite", () => {
74+
const sortingState: SortingState = [{ id: "age", desc: false }];
75+
const setSorting = vi.fn();
76+
const clearSorting = vi.fn();
77+
78+
// Click Desc on age - should toggle to descending
79+
handleSort("age", true, sortingState, setSorting, clearSorting);
80+
81+
expect(setSorting).toHaveBeenCalledWith([{ id: "age", desc: true }]);
82+
expect(clearSorting).not.toHaveBeenCalled();
83+
});
84+
85+
it("correctly calculates priority numbers", () => {
86+
const sortingState: SortingState = [
87+
{ id: "name", desc: false },
88+
{ id: "age", desc: true },
89+
{ id: "dept", desc: false },
90+
];
91+
92+
// Priority is index + 1
93+
const nameSort = sortingState.find((s) => s.id === "name");
94+
const namePriority = nameSort ? sortingState.indexOf(nameSort) + 1 : null;
95+
expect(namePriority).toBe(1);
96+
97+
const deptSort = sortingState.find((s) => s.id === "dept");
98+
const deptPriority = deptSort ? sortingState.indexOf(deptSort) + 1 : null;
99+
expect(deptPriority).toBe(3);
100+
});
101+
102+
it("handles removing column from middle of stack", () => {
103+
const sortingState: SortingState = [
104+
{ id: "name", desc: false },
105+
{ id: "age", desc: true },
106+
{ id: "dept", desc: false },
107+
];
108+
const setSorting = vi.fn();
109+
const clearSorting = vi.fn();
110+
111+
// Click Desc on age again - should remove it
112+
handleSort("age", true, sortingState, setSorting, clearSorting);
113+
114+
expect(clearSorting).toHaveBeenCalled();
115+
// After removal, dept should move from priority 3 to priority 2
116+
});
117+
});

frontend/src/components/data-table/column-header.tsx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/* Copyright 2024 Marimo. All rights reserved. */
22
"use no memo";
33

4-
import type { Column } from "@tanstack/react-table";
4+
import type { Column, Table } from "@tanstack/react-table";
55
import { capitalize } from "lodash-es";
66
import { FilterIcon, MinusIcon, TextIcon, XIcon } from "lucide-react";
77
import { useMemo, useRef, useState } from "react";
@@ -68,13 +68,15 @@ interface DataTableColumnHeaderProps<TData, TValue>
6868
column: Column<TData, TValue>;
6969
header: React.ReactNode;
7070
calculateTopKRows?: CalculateTopKRows;
71+
table?: Table<TData>;
7172
}
7273

7374
export const DataTableColumnHeader = <TData, TValue>({
7475
column,
7576
header,
7677
className,
7778
calculateTopKRows,
79+
table,
7880
}: DataTableColumnHeaderProps<TData, TValue>) => {
7981
const [isFilterValueOpen, setIsFilterValueOpen] = useState(false);
8082
const { locale } = useLocale();
@@ -117,7 +119,7 @@ export const DataTableColumnHeader = <TData, TValue>({
117119
</DropdownMenuTrigger>
118120
<DropdownMenuContent align="start">
119121
{renderDataType(column)}
120-
{renderSorts(column)}
122+
{renderSorts(column, table)}
121123
{renderCopyColumn(column)}
122124
{renderColumnPinning(column)}
123125
{renderColumnWrapping(column)}

frontend/src/components/data-table/columns.tsx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ export function generateColumns<T>({
165165
return row[key as keyof T];
166166
},
167167

168-
header: ({ column }) => {
168+
header: ({ column, table }) => {
169169
const stats = chartSpecModel?.getColumnStats(key);
170170
const dtype = column.columnDef.meta?.dtype;
171171
const headerTitle = headerTooltip?.[key];
@@ -208,6 +208,7 @@ export function generateColumns<T>({
208208
header={headerWithTooltip}
209209
column={column}
210210
calculateTopKRows={calculateTopKRows}
211+
table={table}
211212
/>
212213
);
213214

frontend/src/components/data-table/data-table.tsx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,13 @@ const DataTableInternal = <TData,>({
215215
manualPagination: manualPagination,
216216
getPaginationRowModel: getPaginationRowModel(),
217217
// sorting
218-
...(setSorting ? { onSortingChange: setSorting } : {}),
218+
...(setSorting
219+
? {
220+
onSortingChange: setSorting,
221+
}
222+
: {}),
219223
manualSorting: manualSorting,
224+
enableMultiSort: true,
220225
getSortedRowModel: getSortedRowModel(),
221226
// filtering
222227
manualFiltering: true,

frontend/src/components/data-table/header-items.tsx

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/* Copyright 2024 Marimo. All rights reserved. */
22

33
import { PinLeftIcon, PinRightIcon } from "@radix-ui/react-icons";
4-
import type { Column } from "@tanstack/react-table";
4+
import type { Column, SortDirection, Table } from "@tanstack/react-table";
55
import {
66
AlignJustifyIcon,
77
ArrowDownWideNarrowIcon,
@@ -163,27 +163,80 @@ export function renderCopyColumn<TData, TValue>(column: Column<TData, TValue>) {
163163
const AscIcon = ArrowUpNarrowWideIcon;
164164
const DescIcon = ArrowDownWideNarrowIcon;
165165

166-
export function renderSorts<TData, TValue>(column: Column<TData, TValue>) {
166+
export function renderSorts<TData, TValue>(
167+
column: Column<TData, TValue>,
168+
table?: Table<TData>,
169+
) {
167170
if (!column.getCanSort()) {
168171
return null;
169172
}
170173

174+
const sortDirection = column.getIsSorted();
175+
const sortingIndex = column.getSortIndex();
176+
177+
const sortingState = table?.getState().sorting;
178+
const hasMultiSort = sortingState?.length && sortingState.length > 1;
179+
180+
const renderSortIndex = () => {
181+
return (
182+
<span className="ml-auto text-xs font-medium">{sortingIndex + 1}</span>
183+
);
184+
};
185+
186+
const renderClearSort = () => {
187+
if (!sortDirection) {
188+
return null;
189+
}
190+
191+
if (!hasMultiSort) {
192+
// render clear sort for this column
193+
return (
194+
<DropdownMenuItem onClick={() => column.clearSorting()}>
195+
<ChevronsUpDown className="mo-dropdown-icon" />
196+
Clear sort
197+
</DropdownMenuItem>
198+
);
199+
}
200+
201+
// render clear sort for all columns
202+
return (
203+
<DropdownMenuItem onClick={() => table?.resetSorting()}>
204+
<ChevronsUpDown className="mo-dropdown-icon" />
205+
Clear all sorts
206+
</DropdownMenuItem>
207+
);
208+
};
209+
210+
const toggleSort = (direction: SortDirection) => {
211+
// Clear sort if clicking the same direction
212+
if (sortDirection === direction) {
213+
column.clearSorting();
214+
} else {
215+
// Toggle sort direction
216+
const descending = direction === "desc";
217+
column.toggleSorting(descending, true);
218+
}
219+
};
220+
171221
return (
172222
<>
173-
<DropdownMenuItem onClick={() => column.toggleSorting(false)}>
223+
<DropdownMenuItem
224+
onClick={() => toggleSort("asc")}
225+
className={sortDirection === "asc" ? "bg-accent" : ""}
226+
>
174227
<AscIcon className="mo-dropdown-icon" />
175228
Asc
229+
{sortDirection === "asc" && renderSortIndex()}
176230
</DropdownMenuItem>
177-
<DropdownMenuItem onClick={() => column.toggleSorting(true)}>
231+
<DropdownMenuItem
232+
onClick={() => toggleSort("desc")}
233+
className={sortDirection === "desc" ? "bg-accent" : ""}
234+
>
178235
<DescIcon className="mo-dropdown-icon" />
179236
Desc
237+
{sortDirection === "desc" && renderSortIndex()}
180238
</DropdownMenuItem>
181-
{column.getIsSorted() && (
182-
<DropdownMenuItem onClick={() => column.clearSorting()}>
183-
<ChevronsUpDown className="mo-dropdown-icon" />
184-
Clear sort
185-
</DropdownMenuItem>
186-
)}
239+
{renderClearSort()}
187240
<DropdownMenuSeparator />
188241
</>
189242
);

frontend/src/plugins/impl/DataTablePlugin.tsx

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ type DataTableFunctions = {
205205
sort?: {
206206
by: string;
207207
descending: boolean;
208-
};
208+
}[];
209209
query?: string;
210210
filters?: ConditionType[];
211211
page_number: number;
@@ -298,7 +298,12 @@ export const DataTablePlugin = createPlugin<S>("marimo-table")
298298
.input(
299299
z.object({
300300
sort: z
301-
.object({ by: z.string(), descending: z.boolean() })
301+
.array(
302+
z.object({
303+
by: z.string(),
304+
descending: z.boolean(),
305+
}),
306+
)
302307
.optional(),
303308
query: z.string().optional(),
304309
filters: z.array(ConditionSchema).optional(),
@@ -501,19 +506,15 @@ export const LoadingDataTableComponent = memo(
501506
!props.lazy &&
502507
!pageSizeChanged;
503508

504-
if (sorting.length > 1) {
505-
Logger.warn("Multiple sort columns are not supported");
506-
}
509+
// Convert sorting state to API format
510+
const sortArgs =
511+
sorting.length > 0
512+
? sorting.map((s) => ({ by: s.id, descending: s.desc }))
513+
: undefined;
507514

508515
// If we have sort/search/filter, use the search function
509516
const searchResultsPromise = search<T>({
510-
sort:
511-
sorting.length > 0
512-
? {
513-
by: sorting[0].id,
514-
descending: sorting[0].desc,
515-
}
516-
: undefined,
517+
sort: sortArgs,
517518
query: searchQuery,
518519
page_number: paginationState.pageIndex,
519520
page_size: paginationState.pageSize,
@@ -563,16 +564,15 @@ export const LoadingDataTableComponent = memo(
563564

564565
const getRow = useCallback(
565566
async (rowId: number) => {
567+
const sortArgs =
568+
sorting.length > 0
569+
? sorting.map((s) => ({ by: s.id, descending: s.desc }))
570+
: undefined;
571+
566572
const result = await search<T>({
567573
page_number: rowId,
568574
page_size: 1,
569-
sort:
570-
sorting.length > 0
571-
? {
572-
by: sorting[0].id,
573-
descending: sorting[0].desc,
574-
}
575-
: undefined,
575+
sort: sortArgs,
576576
query: searchQuery,
577577
filters: filters.flatMap((filter) => {
578578
return filterToFilterCondition(

frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ type PluginFunctions = {
6363
sort?: {
6464
by: string;
6565
descending: boolean;
66-
};
66+
}[];
6767
query?: string;
6868
filters?: ConditionType[];
6969
page_number: number;
@@ -117,7 +117,12 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe")
117117
.input(
118118
z.object({
119119
sort: z
120-
.object({ by: z.string(), descending: z.boolean() })
120+
.array(
121+
z.object({
122+
by: z.string(),
123+
descending: z.boolean(),
124+
}),
125+
)
121126
.optional(),
122127
query: z.string().optional(),
123128
filters: z.array(ConditionSchema).optional(),

marimo/_plugins/ui/_impl/dataframes/dataframe.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,18 @@ def _download_as(self, args: DownloadAsArgs) -> str:
303303
def _apply_filters_query_sort(
304304
self,
305305
query: Optional[str],
306-
sort: Optional[SortArgs],
306+
sort: Optional[list[SortArgs]],
307307
) -> TableManager[Any]:
308308
result = self._get_cached_table_manager(self._value, self._limit)
309309

310310
if query:
311311
result = result.search(query)
312312

313-
if sort and sort.by in result.get_column_names():
314-
result = result.sort_values(sort.by, sort.descending)
313+
if sort:
314+
existing_columns = set(result.get_column_names())
315+
valid_sort = [s for s in sort if s.by in existing_columns]
316+
if valid_sort:
317+
result = result.sort_values(valid_sort)
315318

316319
return result
317320

0 commit comments

Comments
 (0)