Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fb86632
Add multi-column sorting with stack-based behavior
lucharo Sep 5, 2025
fdbad05
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2025
1c8e68c
Remove Claude settings file
lucharo Sep 5, 2025
03f13ed
Fix multi-column sorting RPC serialization
lucharo Sep 7, 2025
00cac12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2025
4038124
Add multi-column sorting with stack-based priority
lucharo Oct 3, 2025
819c912
Add tests for multi-column sorting logic
lucharo Oct 3, 2025
b567353
Merge remote-tracking branch 'origin/main' into multi-column-sort
lucharo Oct 3, 2025
2f95bd1
Merge remote-tracking branch 'upstream/main' into multi-column-sort
lucharo Oct 3, 2025
ac9c7f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2025
9022d36
Use list[SortArgs] instead of list[tuple] for cleaner API
lucharo Oct 3, 2025
cc78a02
Simplify sort_values logic in DefaultTableManager
lucharo Oct 3, 2025
a88a12d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2025
b42c4fc
Refactor: reduce duplication in sort UI and data conversion
lucharo Oct 3, 2025
12cb0d2
Address Light2Dark's PR review feedback
lucharo Oct 4, 2025
6a6f2e1
Fix linting issues
lucharo Oct 4, 2025
4dff80d
Fix circular import with TYPE_CHECKING
lucharo Oct 4, 2025
7deb726
Add missing SortDirection type import
lucharo Oct 4, 2025
b3840f5
Remove unnecessary tuple conversion in dataframe.py
lucharo Oct 4, 2025
f9dfab3
Update all tests to use SortArgs instead of tuples
lucharo Oct 5, 2025
573f6a6
Fix mixed-type and None sorting in DefaultTableManager
lucharo Oct 5, 2025
afde587
Address PR feedback: filter invalid sort columns and remove artifacts
lucharo Oct 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .claude/settings.local.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"outputStyle": "friendly-educational"
}
Empty file added P_empty
Empty file.
117 changes: 117 additions & 0 deletions frontend/src/components/data-table/__tests__/header-items.test.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/* Copyright 2024 Marimo. All rights reserved. */

import type { SortingState } from "@tanstack/react-table";
import { describe, expect, it, vi } from "vitest";

describe("multi-column sorting logic", () => {
// Extract the core sorting logic to test in isolation
const handleSort = (
columnId: string,
desc: boolean,
sortingState: SortingState,
setSorting: (state: SortingState) => void,
clearSorting: () => void,
) => {
const currentSort = sortingState.find((s) => s.id === columnId);

if (currentSort && currentSort.desc === desc) {
// Clicking the same sort again - remove it
clearSorting();
} else {
// New sort or different direction - move to end of stack
const otherSorts = sortingState.filter((s) => s.id !== columnId);
const newSort = { id: columnId, desc };
setSorting([...otherSorts, newSort]);
}
};

it("implements stack-based sorting: moves re-clicked column to end", () => {
const sortingState: SortingState = [
{ id: "name", desc: false },
{ id: "age", desc: false },
];
const setSorting = vi.fn();
const clearSorting = vi.fn();

// Click Desc on age - should move age to end with desc=true
handleSort("age", true, sortingState, setSorting, clearSorting);

expect(setSorting).toHaveBeenCalledWith([
{ id: "name", desc: false },
{ id: "age", desc: true },
]);
expect(clearSorting).not.toHaveBeenCalled();
});

it("removes sort when clicking same direction twice", () => {
const sortingState: SortingState = [{ id: "age", desc: false }];
const setSorting = vi.fn();
const clearSorting = vi.fn();

// Click Asc on age again - should remove the sort
handleSort("age", false, sortingState, setSorting, clearSorting);

expect(clearSorting).toHaveBeenCalled();
expect(setSorting).not.toHaveBeenCalled();
});

it("adds new column to end of stack", () => {
const sortingState: SortingState = [{ id: "name", desc: false }];
const setSorting = vi.fn();
const clearSorting = vi.fn();

// Click Asc on age - should add age to end
handleSort("age", false, sortingState, setSorting, clearSorting);

expect(setSorting).toHaveBeenCalledWith([
{ id: "name", desc: false },
{ id: "age", desc: false },
]);
expect(clearSorting).not.toHaveBeenCalled();
});

it("toggles sort direction when clicking opposite", () => {
const sortingState: SortingState = [{ id: "age", desc: false }];
const setSorting = vi.fn();
const clearSorting = vi.fn();

// Click Desc on age - should toggle to descending
handleSort("age", true, sortingState, setSorting, clearSorting);

expect(setSorting).toHaveBeenCalledWith([{ id: "age", desc: true }]);
expect(clearSorting).not.toHaveBeenCalled();
});

it("correctly calculates priority numbers", () => {
const sortingState: SortingState = [
{ id: "name", desc: false },
{ id: "age", desc: true },
{ id: "dept", desc: false },
];

// Priority is index + 1
const nameSort = sortingState.find((s) => s.id === "name");
const namePriority = nameSort ? sortingState.indexOf(nameSort) + 1 : null;
expect(namePriority).toBe(1);

const deptSort = sortingState.find((s) => s.id === "dept");
const deptPriority = deptSort ? sortingState.indexOf(deptSort) + 1 : null;
expect(deptPriority).toBe(3);
});

it("handles removing column from middle of stack", () => {
const sortingState: SortingState = [
{ id: "name", desc: false },
{ id: "age", desc: true },
{ id: "dept", desc: false },
];
const setSorting = vi.fn();
const clearSorting = vi.fn();

// Click Desc on age again - should remove it
handleSort("age", true, sortingState, setSorting, clearSorting);

expect(clearSorting).toHaveBeenCalled();
// After removal, dept should move from priority 3 to priority 2
});
});
6 changes: 4 additions & 2 deletions frontend/src/components/data-table/column-header.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* Copyright 2024 Marimo. All rights reserved. */
"use no memo";

import type { Column } from "@tanstack/react-table";
import type { Column, Table } from "@tanstack/react-table";
import { capitalize } from "lodash-es";
import { FilterIcon, MinusIcon, TextIcon, XIcon } from "lucide-react";
import { useMemo, useRef, useState } from "react";
Expand Down Expand Up @@ -68,13 +68,15 @@ interface DataTableColumnHeaderProps<TData, TValue>
column: Column<TData, TValue>;
header: React.ReactNode;
calculateTopKRows?: CalculateTopKRows;
table?: Table<TData>;
}

export const DataTableColumnHeader = <TData, TValue>({
column,
header,
className,
calculateTopKRows,
table,
}: DataTableColumnHeaderProps<TData, TValue>) => {
const [isFilterValueOpen, setIsFilterValueOpen] = useState(false);
const { locale } = useLocale();
Expand Down Expand Up @@ -117,7 +119,7 @@ export const DataTableColumnHeader = <TData, TValue>({
</DropdownMenuTrigger>
<DropdownMenuContent align="start">
{renderDataType(column)}
{renderSorts(column)}
{renderSorts(column, table)}
{renderCopyColumn(column)}
{renderColumnPinning(column)}
{renderColumnWrapping(column)}
Expand Down
3 changes: 2 additions & 1 deletion frontend/src/components/data-table/columns.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ export function generateColumns<T>({
return row[key as keyof T];
},

header: ({ column }) => {
header: ({ column, table }) => {
const stats = chartSpecModel?.getColumnStats(key);
const dtype = column.columnDef.meta?.dtype;
const headerTitle = headerTooltip?.[key];
Expand Down Expand Up @@ -208,6 +208,7 @@ export function generateColumns<T>({
header={headerWithTooltip}
column={column}
calculateTopKRows={calculateTopKRows}
table={table}
/>
);

Expand Down
7 changes: 6 additions & 1 deletion frontend/src/components/data-table/data-table.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,13 @@ const DataTableInternal = <TData,>({
manualPagination: manualPagination,
getPaginationRowModel: getPaginationRowModel(),
// sorting
...(setSorting ? { onSortingChange: setSorting } : {}),
...(setSorting
? {
onSortingChange: setSorting,
}
: {}),
manualSorting: manualSorting,
enableMultiSort: true,
getSortedRowModel: getSortedRowModel(),
// filtering
manualFiltering: true,
Expand Down
73 changes: 63 additions & 10 deletions frontend/src/components/data-table/header-items.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* Copyright 2024 Marimo. All rights reserved. */

import { PinLeftIcon, PinRightIcon } from "@radix-ui/react-icons";
import type { Column } from "@tanstack/react-table";
import type { Column, SortDirection, Table } from "@tanstack/react-table";
import {
AlignJustifyIcon,
ArrowDownWideNarrowIcon,
Expand Down Expand Up @@ -163,27 +163,80 @@ export function renderCopyColumn<TData, TValue>(column: Column<TData, TValue>) {
const AscIcon = ArrowUpNarrowWideIcon;
const DescIcon = ArrowDownWideNarrowIcon;

export function renderSorts<TData, TValue>(column: Column<TData, TValue>) {
export function renderSorts<TData, TValue>(
column: Column<TData, TValue>,
table?: Table<TData>,
) {
if (!column.getCanSort()) {
return null;
}

const sortDirection = column.getIsSorted();
const sortingIndex = column.getSortIndex();

const sortingState = table?.getState().sorting;
const hasMultiSort = sortingState?.length && sortingState.length > 1;

const renderSortIndex = () => {
return (
<span className="ml-auto text-xs font-medium">{sortingIndex + 1}</span>
);
};

const renderClearSort = () => {
if (!sortDirection) {
return null;
}

if (!hasMultiSort) {
// render clear sort for this column
return (
<DropdownMenuItem onClick={() => column.clearSorting()}>
<ChevronsUpDown className="mo-dropdown-icon" />
Clear sort
</DropdownMenuItem>
);
}

// render clear sort for all columns
return (
<DropdownMenuItem onClick={() => table?.resetSorting()}>
<ChevronsUpDown className="mo-dropdown-icon" />
Clear all sorts
</DropdownMenuItem>
);
};

const toggleSort = (direction: SortDirection) => {
// Clear sort if clicking the same direction
if (sortDirection === direction) {
column.clearSorting();
} else {
// Toggle sort direction
const descending = direction === "desc";
column.toggleSorting(descending, true);
}
};

return (
<>
<DropdownMenuItem onClick={() => column.toggleSorting(false)}>
<DropdownMenuItem
onClick={() => toggleSort("asc")}
className={sortDirection === "asc" ? "bg-accent" : ""}
>
<AscIcon className="mo-dropdown-icon" />
Asc
{sortDirection === "asc" && renderSortIndex()}
</DropdownMenuItem>
<DropdownMenuItem onClick={() => column.toggleSorting(true)}>
<DropdownMenuItem
onClick={() => toggleSort("desc")}
className={sortDirection === "desc" ? "bg-accent" : ""}
>
<DescIcon className="mo-dropdown-icon" />
Desc
{sortDirection === "desc" && renderSortIndex()}
</DropdownMenuItem>
{column.getIsSorted() && (
<DropdownMenuItem onClick={() => column.clearSorting()}>
<ChevronsUpDown className="mo-dropdown-icon" />
Clear sort
</DropdownMenuItem>
)}
{renderClearSort()}
<DropdownMenuSeparator />
</>
);
Expand Down
38 changes: 19 additions & 19 deletions frontend/src/plugins/impl/DataTablePlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ type DataTableFunctions = {
sort?: {
by: string;
descending: boolean;
};
}[];
query?: string;
filters?: ConditionType[];
page_number: number;
Expand Down Expand Up @@ -298,7 +298,12 @@ export const DataTablePlugin = createPlugin<S>("marimo-table")
.input(
z.object({
sort: z
.object({ by: z.string(), descending: z.boolean() })
.array(
z.object({
by: z.string(),
descending: z.boolean(),
}),
)
.optional(),
query: z.string().optional(),
filters: z.array(ConditionSchema).optional(),
Expand Down Expand Up @@ -501,19 +506,15 @@ export const LoadingDataTableComponent = memo(
!props.lazy &&
!pageSizeChanged;

if (sorting.length > 1) {
Logger.warn("Multiple sort columns are not supported");
}
// Convert sorting state to API format
const sortArgs =
sorting.length > 0
? sorting.map((s) => ({ by: s.id, descending: s.desc }))
: undefined;

// If we have sort/search/filter, use the search function
const searchResultsPromise = search<T>({
sort:
sorting.length > 0
? {
by: sorting[0].id,
descending: sorting[0].desc,
}
: undefined,
sort: sortArgs,
query: searchQuery,
page_number: paginationState.pageIndex,
page_size: paginationState.pageSize,
Expand Down Expand Up @@ -563,16 +564,15 @@ export const LoadingDataTableComponent = memo(

const getRow = useCallback(
async (rowId: number) => {
const sortArgs =
sorting.length > 0
? sorting.map((s) => ({ by: s.id, descending: s.desc }))
: undefined;

const result = await search<T>({
page_number: rowId,
page_size: 1,
sort:
sorting.length > 0
? {
by: sorting[0].id,
descending: sorting[0].desc,
}
: undefined,
sort: sortArgs,
query: searchQuery,
filters: filters.flatMap((filter) => {
return filterToFilterCondition(
Expand Down
9 changes: 7 additions & 2 deletions frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type PluginFunctions = {
sort?: {
by: string;
descending: boolean;
};
}[];
query?: string;
filters?: ConditionType[];
page_number: number;
Expand Down Expand Up @@ -117,7 +117,12 @@ export const DataFramePlugin = createPlugin<S>("marimo-dataframe")
.input(
z.object({
sort: z
.object({ by: z.string(), descending: z.boolean() })
.array(
z.object({
by: z.string(),
descending: z.boolean(),
}),
)
.optional(),
query: z.string().optional(),
filters: z.array(ConditionSchema).optional(),
Expand Down
Loading
Loading