diff --git a/frontend/src/components/data-table/__tests__/header-items.test.tsx b/frontend/src/components/data-table/__tests__/header-items.test.tsx new file mode 100644 index 00000000000..3f0db55f950 --- /dev/null +++ b/frontend/src/components/data-table/__tests__/header-items.test.tsx @@ -0,0 +1,117 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import type { SortingState } from "@tanstack/react-table"; +import { describe, expect, it, vi } from "vitest"; + +describe("multi-column sorting logic", () => { + // Extract the core sorting logic to test in isolation + const handleSort = ( + columnId: string, + desc: boolean, + sortingState: SortingState, + setSorting: (state: SortingState) => void, + clearSorting: () => void, + ) => { + const currentSort = sortingState.find((s) => s.id === columnId); + + if (currentSort && currentSort.desc === desc) { + // Clicking the same sort again - remove it + clearSorting(); + } else { + // New sort or different direction - move to end of stack + const otherSorts = sortingState.filter((s) => s.id !== columnId); + const newSort = { id: columnId, desc }; + setSorting([...otherSorts, newSort]); + } + }; + + it("implements stack-based sorting: moves re-clicked column to end", () => { + const sortingState: SortingState = [ + { id: "name", desc: false }, + { id: "age", desc: false }, + ]; + const setSorting = vi.fn(); + const clearSorting = vi.fn(); + + // Click Desc on age - should move age to end with desc=true + handleSort("age", true, sortingState, setSorting, clearSorting); + + expect(setSorting).toHaveBeenCalledWith([ + { id: "name", desc: false }, + { id: "age", desc: true }, + ]); + expect(clearSorting).not.toHaveBeenCalled(); + }); + + it("removes sort when clicking same direction twice", () => { + const sortingState: SortingState = [{ id: "age", desc: false }]; + const setSorting = vi.fn(); + const clearSorting = vi.fn(); + + // Click Asc on age again - should remove the sort + handleSort("age", false, sortingState, setSorting, clearSorting); + + expect(clearSorting).toHaveBeenCalled(); + expect(setSorting).not.toHaveBeenCalled(); + }); + + it("adds new column to end of stack", () => { + const sortingState: SortingState = [{ id: "name", desc: false }]; + const setSorting = vi.fn(); + const clearSorting = vi.fn(); + + // Click Asc on age - should add age to end + handleSort("age", false, sortingState, setSorting, clearSorting); + + expect(setSorting).toHaveBeenCalledWith([ + { id: "name", desc: false }, + { id: "age", desc: false }, + ]); + expect(clearSorting).not.toHaveBeenCalled(); + }); + + it("toggles sort direction when clicking opposite", () => { + const sortingState: SortingState = [{ id: "age", desc: false }]; + const setSorting = vi.fn(); + const clearSorting = vi.fn(); + + // Click Desc on age - should toggle to descending + handleSort("age", true, sortingState, setSorting, clearSorting); + + expect(setSorting).toHaveBeenCalledWith([{ id: "age", desc: true }]); + expect(clearSorting).not.toHaveBeenCalled(); + }); + + it("correctly calculates priority numbers", () => { + const sortingState: SortingState = [ + { id: "name", desc: false }, + { id: "age", desc: true }, + { id: "dept", desc: false }, + ]; + + // Priority is index + 1 + const nameSort = sortingState.find((s) => s.id === "name"); + const namePriority = nameSort ? sortingState.indexOf(nameSort) + 1 : null; + expect(namePriority).toBe(1); + + const deptSort = sortingState.find((s) => s.id === "dept"); + const deptPriority = deptSort ? sortingState.indexOf(deptSort) + 1 : null; + expect(deptPriority).toBe(3); + }); + + it("handles removing column from middle of stack", () => { + const sortingState: SortingState = [ + { id: "name", desc: false }, + { id: "age", desc: true }, + { id: "dept", desc: false }, + ]; + const setSorting = vi.fn(); + const clearSorting = vi.fn(); + + // Click Desc on age again - should remove it + handleSort("age", true, sortingState, setSorting, clearSorting); + + expect(clearSorting).toHaveBeenCalled(); + // After removal, dept should move from priority 3 to priority 2 + }); +}); diff --git a/frontend/src/components/data-table/column-header.tsx b/frontend/src/components/data-table/column-header.tsx index d95072652aa..20224fd2815 100644 --- a/frontend/src/components/data-table/column-header.tsx +++ b/frontend/src/components/data-table/column-header.tsx @@ -1,7 +1,7 @@ /* Copyright 2024 Marimo. All rights reserved. */ "use no memo"; -import type { Column } from "@tanstack/react-table"; +import type { Column, Table } from "@tanstack/react-table"; import { capitalize } from "lodash-es"; import { FilterIcon, MinusIcon, TextIcon, XIcon } from "lucide-react"; import { useMemo, useRef, useState } from "react"; @@ -68,6 +68,7 @@ interface DataTableColumnHeaderProps column: Column; header: React.ReactNode; calculateTopKRows?: CalculateTopKRows; + table?: Table; } export const DataTableColumnHeader = ({ @@ -75,6 +76,7 @@ export const DataTableColumnHeader = ({ header, className, calculateTopKRows, + table, }: DataTableColumnHeaderProps) => { const [isFilterValueOpen, setIsFilterValueOpen] = useState(false); const { locale } = useLocale(); @@ -117,7 +119,7 @@ export const DataTableColumnHeader = ({ {renderDataType(column)} - {renderSorts(column)} + {renderSorts(column, table)} {renderCopyColumn(column)} {renderColumnPinning(column)} {renderColumnWrapping(column)} diff --git a/frontend/src/components/data-table/columns.tsx b/frontend/src/components/data-table/columns.tsx index d29021c15d0..a46d2d75192 100644 --- a/frontend/src/components/data-table/columns.tsx +++ b/frontend/src/components/data-table/columns.tsx @@ -165,7 +165,7 @@ export function generateColumns({ return row[key as keyof T]; }, - header: ({ column }) => { + header: ({ column, table }) => { const stats = chartSpecModel?.getColumnStats(key); const dtype = column.columnDef.meta?.dtype; const headerTitle = headerTooltip?.[key]; @@ -208,6 +208,7 @@ export function generateColumns({ header={headerWithTooltip} column={column} calculateTopKRows={calculateTopKRows} + table={table} /> ); diff --git a/frontend/src/components/data-table/data-table.tsx b/frontend/src/components/data-table/data-table.tsx index cd015795084..21fd4000531 100644 --- a/frontend/src/components/data-table/data-table.tsx +++ b/frontend/src/components/data-table/data-table.tsx @@ -215,8 +215,13 @@ const DataTableInternal = ({ manualPagination: manualPagination, getPaginationRowModel: getPaginationRowModel(), // sorting - ...(setSorting ? { onSortingChange: setSorting } : {}), + ...(setSorting + ? { + onSortingChange: setSorting, + } + : {}), manualSorting: manualSorting, + enableMultiSort: true, getSortedRowModel: getSortedRowModel(), // filtering manualFiltering: true, diff --git a/frontend/src/components/data-table/header-items.tsx b/frontend/src/components/data-table/header-items.tsx index 500e5270364..7f1de70f3af 100644 --- a/frontend/src/components/data-table/header-items.tsx +++ b/frontend/src/components/data-table/header-items.tsx @@ -1,7 +1,7 @@ /* Copyright 2024 Marimo. All rights reserved. */ import { PinLeftIcon, PinRightIcon } from "@radix-ui/react-icons"; -import type { Column } from "@tanstack/react-table"; +import type { Column, SortDirection, Table } from "@tanstack/react-table"; import { AlignJustifyIcon, ArrowDownWideNarrowIcon, @@ -163,27 +163,80 @@ export function renderCopyColumn(column: Column) { const AscIcon = ArrowUpNarrowWideIcon; const DescIcon = ArrowDownWideNarrowIcon; -export function renderSorts(column: Column) { +export function renderSorts( + column: Column, + table?: Table, +) { if (!column.getCanSort()) { return null; } + const sortDirection = column.getIsSorted(); + const sortingIndex = column.getSortIndex(); + + const sortingState = table?.getState().sorting; + const hasMultiSort = sortingState?.length && sortingState.length > 1; + + const renderSortIndex = () => { + return ( + {sortingIndex + 1} + ); + }; + + const renderClearSort = () => { + if (!sortDirection) { + return null; + } + + if (!hasMultiSort) { + // render clear sort for this column + return ( + column.clearSorting()}> + + Clear sort + + ); + } + + // render clear sort for all columns + return ( + table?.resetSorting()}> + + Clear all sorts + + ); + }; + + const toggleSort = (direction: SortDirection) => { + // Clear sort if clicking the same direction + if (sortDirection === direction) { + column.clearSorting(); + } else { + // Toggle sort direction + const descending = direction === "desc"; + column.toggleSorting(descending, true); + } + }; + return ( <> - column.toggleSorting(false)}> + toggleSort("asc")} + className={sortDirection === "asc" ? "bg-accent" : ""} + > Asc + {sortDirection === "asc" && renderSortIndex()} - column.toggleSorting(true)}> + toggleSort("desc")} + className={sortDirection === "desc" ? "bg-accent" : ""} + > Desc + {sortDirection === "desc" && renderSortIndex()} - {column.getIsSorted() && ( - column.clearSorting()}> - - Clear sort - - )} + {renderClearSort()} ); diff --git a/frontend/src/plugins/impl/DataTablePlugin.tsx b/frontend/src/plugins/impl/DataTablePlugin.tsx index 15fa13c5d8d..fae3d027ad1 100644 --- a/frontend/src/plugins/impl/DataTablePlugin.tsx +++ b/frontend/src/plugins/impl/DataTablePlugin.tsx @@ -205,7 +205,7 @@ type DataTableFunctions = { sort?: { by: string; descending: boolean; - }; + }[]; query?: string; filters?: ConditionType[]; page_number: number; @@ -298,7 +298,12 @@ export const DataTablePlugin = createPlugin("marimo-table") .input( z.object({ sort: z - .object({ by: z.string(), descending: z.boolean() }) + .array( + z.object({ + by: z.string(), + descending: z.boolean(), + }), + ) .optional(), query: z.string().optional(), filters: z.array(ConditionSchema).optional(), @@ -501,19 +506,15 @@ export const LoadingDataTableComponent = memo( !props.lazy && !pageSizeChanged; - if (sorting.length > 1) { - Logger.warn("Multiple sort columns are not supported"); - } + // Convert sorting state to API format + const sortArgs = + sorting.length > 0 + ? sorting.map((s) => ({ by: s.id, descending: s.desc })) + : undefined; // If we have sort/search/filter, use the search function const searchResultsPromise = search({ - sort: - sorting.length > 0 - ? { - by: sorting[0].id, - descending: sorting[0].desc, - } - : undefined, + sort: sortArgs, query: searchQuery, page_number: paginationState.pageIndex, page_size: paginationState.pageSize, @@ -563,16 +564,15 @@ export const LoadingDataTableComponent = memo( const getRow = useCallback( async (rowId: number) => { + const sortArgs = + sorting.length > 0 + ? sorting.map((s) => ({ by: s.id, descending: s.desc })) + : undefined; + const result = await search({ page_number: rowId, page_size: 1, - sort: - sorting.length > 0 - ? { - by: sorting[0].id, - descending: sorting[0].desc, - } - : undefined, + sort: sortArgs, query: searchQuery, filters: filters.flatMap((filter) => { return filterToFilterCondition( diff --git a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx index b3018dcc28b..e0a2f1d43b6 100644 --- a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx +++ b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx @@ -63,7 +63,7 @@ type PluginFunctions = { sort?: { by: string; descending: boolean; - }; + }[]; query?: string; filters?: ConditionType[]; page_number: number; @@ -117,7 +117,12 @@ export const DataFramePlugin = createPlugin("marimo-dataframe") .input( z.object({ sort: z - .object({ by: z.string(), descending: z.boolean() }) + .array( + z.object({ + by: z.string(), + descending: z.boolean(), + }), + ) .optional(), query: z.string().optional(), filters: z.array(ConditionSchema).optional(), diff --git a/marimo/_plugins/ui/_impl/dataframes/dataframe.py b/marimo/_plugins/ui/_impl/dataframes/dataframe.py index f987659e4b9..d97427a1aac 100644 --- a/marimo/_plugins/ui/_impl/dataframes/dataframe.py +++ b/marimo/_plugins/ui/_impl/dataframes/dataframe.py @@ -303,15 +303,18 @@ def _download_as(self, args: DownloadAsArgs) -> str: def _apply_filters_query_sort( self, query: Optional[str], - sort: Optional[SortArgs], + sort: Optional[list[SortArgs]], ) -> TableManager[Any]: result = self._get_cached_table_manager(self._value, self._limit) if query: result = result.search(query) - if sort and sort.by in result.get_column_names(): - result = result.sort_values(sort.by, sort.descending) + if sort: + existing_columns = set(result.get_column_names()) + valid_sort = [s for s in sort if s.by in existing_columns] + if valid_sort: + result = result.sort_values(valid_sort) return result diff --git a/marimo/_plugins/ui/_impl/table.py b/marimo/_plugins/ui/_impl/table.py index 8d00c85110c..4e9adf9839c 100644 --- a/marimo/_plugins/ui/_impl/table.py +++ b/marimo/_plugins/ui/_impl/table.py @@ -116,12 +116,18 @@ class ColumnSummaries: MaxColumnsType = Union[int, None, MaxColumnsNotProvided] +@dataclass(frozen=True) +class SortArgs: + by: ColumnName + descending: bool + + @dataclass(frozen=True) class SearchTableArgs: page_size: int page_number: int query: Optional[str] = None - sort: Optional[SortArgs] = None + sort: Optional[list[SortArgs]] = None filters: Optional[list[Condition]] = None limit: Optional[int] = None max_columns: Optional[Union[int, MaxColumnsNotProvided]] = ( @@ -143,12 +149,6 @@ class SearchTableResponse: ] = None -@dataclass(frozen=True) -class SortArgs: - by: ColumnName - descending: bool - - @dataclass class GetRowIdsResponse: row_ids: list[int] @@ -1135,18 +1135,22 @@ def _get_data_url(self, args: EmptyArgs) -> GetDataUrlResponse: @functools.lru_cache(maxsize=1) # noqa: B019 def _apply_filters_query_sort_cached( self, - filters: Optional[list[Condition]], + filters: Optional[tuple[Condition, ...]], query: Optional[str], - sort: Optional[SortArgs], + sort: Optional[tuple[SortArgs, ...]], ) -> TableManager[Any]: """Cached version that expects hashable arguments.""" - return self._apply_filters_query_sort(filters, query, sort) + return self._apply_filters_query_sort( + list(filters) if filters else None, + query, + list(sort) if sort else None, + ) def _apply_filters_query_sort( self, filters: Optional[list[Condition]], query: Optional[str], - sort: Optional[SortArgs], + sort: Optional[list[SortArgs]], ) -> TableManager[Any]: result = self._manager @@ -1175,8 +1179,11 @@ def _apply_filters_query_sort( if query: result = result.search(query) - if sort and sort.by in result.get_column_names(): - result = result.sort_values(sort.by, sort.descending) + if sort: + existing_columns = set(result.get_column_names()) + valid_sort = [s for s in sort if s.by in existing_columns] + if valid_sort: + result = result.sort_values(valid_sort) return result @@ -1355,7 +1362,7 @@ def clamp_rows_and_columns(manager: TableManager[Any]) -> str: result = filter_function( tuple(args.filters) if args.filters else None, # type: ignore args.query, - args.sort, + tuple(args.sort) if args.sort else None, # type: ignore ) # Save the manager to be used for selection diff --git a/marimo/_plugins/ui/_impl/tables/default_table.py b/marimo/_plugins/ui/_impl/tables/default_table.py index 7342d24caed..5654a345c7d 100644 --- a/marimo/_plugins/ui/_impl/tables/default_table.py +++ b/marimo/_plugins/ui/_impl/tables/default_table.py @@ -4,7 +4,7 @@ import functools from collections import defaultdict from collections.abc import Sequence -from typing import Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from marimo._data.models import BinValue, ColumnStats, ExternalDataType from marimo._dependencies.dependencies import DependencyManager @@ -17,6 +17,9 @@ format_column, format_row, ) + +if TYPE_CHECKING: + from marimo._plugins.ui._impl.table import SortArgs from marimo._plugins.ui._impl.tables.pandas_table import ( PandasTableManagerFactory, ) @@ -361,74 +364,82 @@ def get_unique_column_values(self, column: str) -> list[str | int | float]: def get_sample_values(self, column: str) -> list[Any]: return self._as_table_manager().get_sample_values(column) - def sort_values( - self, by: ColumnName, descending: bool - ) -> DefaultTableManager: + def sort_values(self, by: list[SortArgs]) -> DefaultTableManager: + if not by: + return self + if isinstance(self.data, dict) and self.is_column_oriented: - # For column-oriented data, extract the sort column and get sorted indices - sort_column = cast(list[Any], self.data[by]) - try: - sorted_indices = sorted( - range(len(sort_column)), - key=lambda i: sort_column[i], - reverse=descending, - ) - except TypeError: - # Handle when values are not comparable - def sort_func_str(i: int) -> tuple[bool, str] | str: - # For ascending, generate a tuple of (is_none, value) - # (True, None) will be for None values - # (False, x) will be for other values. - # As False < True, None values will be sorted to the end. - - # For descending, (is_not_none, value) tuple - # (False, None) will be for None values. - # (True, x) will be for other values. - # As True > False, other values come before None values - if descending: - return ( - sort_column[i] is not None, - str(sort_column[i]), - ) - else: - return str(sort_column[i]) + # Column-oriented: sort indices, then reorder all columns + data_dict = cast(dict[str, list[Any]], self.data) + first_column = next(iter(data_dict.values())) + num_rows = len(first_column) + indices = list(range(num_rows)) + + # Apply sorts in reverse order for stable multi-column sorting + for sort_arg in reversed(by): + values = data_dict[sort_arg.by] + + # Separate None and non-None indices + none_indices = [i for i in indices if values[i] is None] + non_none_indices = [ + i for i in indices if values[i] is not None + ] + + # Try natural comparison first, fall back to string on mixed types + try: + non_none_indices = sorted( + non_none_indices, + key=lambda i: values[i], + reverse=sort_arg.descending, + ) + except TypeError: + # Mixed types - use string comparison + non_none_indices = sorted( + non_none_indices, + key=lambda i: str(values[i]), + reverse=sort_arg.descending, + ) + + # None values always go last + indices = non_none_indices + none_indices - sorted_indices = sorted( - range(len(sort_column)), - key=sort_func_str, - reverse=descending, - ) - # Apply sorted indices to each column while maintaining column orientation return DefaultTableManager( cast( JsonTableData, { - col: [ - cast(list[Any], values)[i] for i in sorted_indices - ] - for col, values in self.data.items() + col: [col_values[i] for i in indices] + for col, col_values in data_dict.items() }, ) ) - # For row-major data, continue with existing logic - normalized = self._normalize_data(self.data) - try: + # Row-oriented: sort rows directly + data = self._normalize_data(self.data) + for sort_arg in reversed(by): + # Separate None and non-None rows + none_rows = [row for row in data if row[sort_arg.by] is None] + non_none_rows = [ + row for row in data if row[sort_arg.by] is not None + ] - def sort_func_col(x: dict[str, Any]) -> tuple[bool, Any]: - is_none = x[by] is not None if descending else x[by] is None - return (is_none, x[by]) + # Try natural comparison first, fall back to string on mixed types + try: + non_none_rows = sorted( + non_none_rows, + key=lambda row: row[sort_arg.by], + reverse=sort_arg.descending, + ) + except TypeError: + # Mixed types - use string comparison + non_none_rows = sorted( + non_none_rows, + key=lambda row: str(row[sort_arg.by]), + reverse=sort_arg.descending, + ) - data = sorted(normalized, key=sort_func_col, reverse=descending) - except TypeError: - # Handle when all values are not comparable - def sort_func_col_str(x: dict[str, Any]) -> tuple[bool, str]: - is_none = x[by] is not None if descending else x[by] is None - return (is_none, str(x[by])) + # None values always go last + data = non_none_rows + none_rows - data = sorted( - normalized, key=sort_func_col_str, reverse=descending - ) return DefaultTableManager(data) @staticmethod diff --git a/marimo/_plugins/ui/_impl/tables/narwhals_table.py b/marimo/_plugins/ui/_impl/tables/narwhals_table.py index d3ae8b2ec4e..41ffe317c45 100644 --- a/marimo/_plugins/ui/_impl/tables/narwhals_table.py +++ b/marimo/_plugins/ui/_impl/tables/narwhals_table.py @@ -5,7 +5,7 @@ import functools import io from functools import cached_property -from typing import Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast import msgspec import narwhals.stable.v2 as nw @@ -41,6 +41,9 @@ unwrap_py_scalar, ) +if TYPE_CHECKING: + from marimo._plugins.ui._impl.table import SortArgs + LOGGER = _loggers.marimo_logger() UNSTABLE_API_WARNING = "`Series.hist` is being called from the stable API although considered an unstable feature." @@ -680,17 +683,17 @@ def to_primitive(value: Any) -> str | int | float: # May be metadata-only frame return [] - def sort_values( - self, by: ColumnName, descending: bool - ) -> TableManager[Any]: - if is_narwhals_lazyframe(self.data): - return self.with_new_data( - self.data.sort(by, descending=descending, nulls_last=True) - ) - else: - return self.with_new_data( - self.data.sort(by, descending=descending, nulls_last=True) - ) + def sort_values(self, by: list[SortArgs]) -> TableManager[Any]: + if not by: + return self + + # Extract columns and descending flags for Narwhals/Polars + columns = [sort_arg.by for sort_arg in by] + descending = [sort_arg.descending for sort_arg in by] + + return self.with_new_data( + self.data.sort(columns, descending=descending, nulls_last=True) + ) def __repr__(self) -> str: rows = self.get_num_rows(force=False) diff --git a/marimo/_plugins/ui/_impl/tables/table_manager.py b/marimo/_plugins/ui/_impl/tables/table_manager.py index b08c1eecb24..849e4c28b3e 100644 --- a/marimo/_plugins/ui/_impl/tables/table_manager.py +++ b/marimo/_plugins/ui/_impl/tables/table_manager.py @@ -3,7 +3,15 @@ import abc from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Generic, + NamedTuple, + Optional, + TypeVar, + Union, +) from marimo._data.models import ( BinValue, @@ -13,6 +21,9 @@ ) from marimo._plugins.ui._impl.tables.format import FormatMapping +if TYPE_CHECKING: + from marimo._plugins.ui._impl.table import SortArgs + T = TypeVar("T") ColumnName = str @@ -79,9 +90,7 @@ def supports_filters(self) -> bool: pass @abc.abstractmethod - def sort_values( - self, by: ColumnName, descending: bool - ) -> TableManager[Any]: + def sort_values(self, by: list[SortArgs]) -> TableManager[Any]: pass @abc.abstractmethod diff --git a/tests/_plugins/ui/_impl/tables/test_default_table.py b/tests/_plugins/ui/_impl/tables/test_default_table.py index 736c1104d4d..aef8aeebf0e 100644 --- a/tests/_plugins/ui/_impl/tables/test_default_table.py +++ b/tests/_plugins/ui/_impl/tables/test_default_table.py @@ -11,7 +11,7 @@ from marimo._dependencies.dependencies import DependencyManager from marimo._output.hypertext import Html -from marimo._plugins.ui._impl.table import _validate_header_tooltip +from marimo._plugins.ui._impl.table import SortArgs, _validate_header_tooltip from marimo._plugins.ui._impl.tables.default_table import DefaultTableManager from marimo._plugins.ui._impl.tables.table_manager import ( TableCell, @@ -108,7 +108,9 @@ def test_take_out_of_bounds(self) -> None: assert limited_manager.data == [] def test_sort(self) -> None: - sorted_data = self.manager.sort_values(by="name", descending=True).data + sorted_data = self.manager.sort_values( + [SortArgs(by="name", descending=True)] + ).data expected_data = [ {"name": "Eve", "age": 22, "birth_year": date(2002, 1, 30)}, {"name": "Dave", "age": 28, "birth_year": date(1996, 3, 5)}, @@ -119,7 +121,7 @@ def test_sort(self) -> None: assert sorted_data == expected_data # reverse sort sorted_data = self.manager.sort_values( - by="name", descending=False + [SortArgs(by="name", descending=False)] ).data expected_data = [ {"name": "Alice", "age": 30, "birth_year": date(1994, 5, 24)}, @@ -135,7 +137,7 @@ def test_sort_null_values(self) -> None: data_with_nan[1]["age"] = None manager_with_nan = DefaultTableManager(data_with_nan) sorted_data = manager_with_nan.sort_values( - by="age", descending=False + [SortArgs(by="age", descending=False)] ).data last_row = sorted_data[-1] @@ -150,7 +152,7 @@ def test_sort_null_values(self) -> None: # descending sorted_data = manager_with_nan.sort_values( - by="age", descending=True + [SortArgs(by="age", descending=False)] ).data last_row = sorted_data[-1] assert last_row == expected_last_row @@ -160,29 +162,35 @@ def test_sort_null_values(self) -> None: data_with_strings[1]["name"] = None manager_with_strings = DefaultTableManager(data_with_strings) sorted_data = manager_with_strings.sort_values( - by="name", descending=False + [SortArgs(by="name", descending=False)] ).data assert sorted_data[-1]["name"] is None # strings descending sorted_data = manager_with_strings.sort_values( - by="name", descending=True + [SortArgs(by="name", descending=False)] ).data assert sorted_data[-1]["name"] is None def test_sort_single_values(self) -> None: manager = DefaultTableManager([1, 3, 2]) - sorted_data = manager.sort_values(by="value", descending=True).data + sorted_data = manager.sort_values( + [SortArgs(by="value", descending=True)] + ).data expected_data = [{"value": 3}, {"value": 2}, {"value": 1}] assert sorted_data == expected_data # reverse sort - sorted_data = manager.sort_values(by="value", descending=False).data + sorted_data = manager.sort_values( + [SortArgs(by="value", descending=False)] + ).data expected_data = [{"value": 1}, {"value": 2}, {"value": 3}] assert sorted_data == expected_data def test_mixed_values(self) -> None: manager = DefaultTableManager([1, "foo", 2, False]) - sorted_data = manager.sort_values(by="value", descending=True).data + sorted_data = manager.sort_values( + [SortArgs(by="value", descending=True)] + ).data expected_data = [ {"value": "foo"}, {"value": False}, @@ -191,7 +199,9 @@ def test_mixed_values(self) -> None: ] assert sorted_data == expected_data # reverse sort - sorted_data = manager.sort_values(by="value", descending=False).data + sorted_data = manager.sort_values( + [SortArgs(by="value", descending=False)] + ).data expected_data = [ {"value": 1}, {"value": 2}, @@ -200,6 +210,101 @@ def test_mixed_values(self) -> None: ] assert sorted_data == expected_data + def test_multi_column_sort_integers_then_strings(self) -> None: + """Test multi-column sorting with integers then strings.""" + data = [ + {"category": 1, "name": "Charlie"}, + {"category": 1, "name": "Alice"}, + {"category": 2, "name": "Bob"}, + ] + manager = DefaultTableManager(data) + + sorted_data = manager.sort_values( + by=[ + SortArgs(by="category", descending=False), + SortArgs(by="name", descending=False), + ] + ).data + expected_data = [ + {"category": 1, "name": "Alice"}, + {"category": 1, "name": "Charlie"}, + {"category": 2, "name": "Bob"}, + ] + assert sorted_data == expected_data + + def test_multi_column_sort_mixed_directions(self) -> None: + """Test multi-column sorting with mixed ascending/descending directions.""" + data = [ + {"priority": 1, "score": 85}, + {"priority": 1, "score": 90}, + {"priority": 2, "score": 70}, + ] + manager = DefaultTableManager(data) + + sorted_data = manager.sort_values( + by=[ + SortArgs(by="priority", descending=False), + SortArgs(by="score", descending=True), + ] + ).data + expected_data = [ + {"priority": 1, "score": 90}, + {"priority": 1, "score": 85}, + {"priority": 2, "score": 70}, + ] + assert sorted_data == expected_data + + def test_multi_column_sort_with_none_values(self) -> None: + """Test multi-column sorting with None values in secondary column.""" + data = [ + {"group": 1, "value": None}, + {"group": 1, "value": 10}, + {"group": 2, "value": 5}, + ] + manager = DefaultTableManager(data) + + sorted_data = manager.sort_values( + by=[ + SortArgs(by="group", descending=False), + SortArgs(by="value", descending=False), + ] + ).data + expected_data = [ + {"group": 1, "value": 10}, + {"group": 1, "value": None}, + {"group": 2, "value": 5}, + ] + assert sorted_data == expected_data + + def test_multi_column_sort_mixed_types_in_column(self) -> None: + """Test multi-column sorting with mixed types in a single column.""" + data = [ + {"id": 1, "value": "string"}, + {"id": 1, "value": 42}, + {"id": 2, "value": True}, + ] + manager = DefaultTableManager(data) + + # Should fall back to string comparison for mixed types + sorted_data = manager.sort_values( + by=[ + SortArgs(by="id", descending=False), + SortArgs(by="value", descending=False), + ] + ).data + expected_data = [ + {"id": 1, "value": 42}, + {"id": 1, "value": "string"}, + {"id": 2, "value": True}, + ] + assert sorted_data == expected_data + + def test_multi_column_sort_empty_list(self) -> None: + """Test that empty sort parameters return original data.""" + manager = DefaultTableManager(self.data) + sorted_data = manager.sort_values(by=[]).data + assert sorted_data == self.data + def test_search(self) -> None: searched_manager = self.manager.search("alice") expected_data = [ @@ -481,7 +586,9 @@ def test_take_out_of_bounds(self) -> None: assert limited_manager.data["name"] == [] def test_sort(self) -> None: - sorted_data = self.manager.sort_values(by="name", descending=True).data + sorted_data = self.manager.sort_values( + [SortArgs(by="name", descending=True)] + ).data expected_data = { "name": ["Eve", "Dave", "Charlie", "Bob", "Alice"], "age": [22, 28, 35, 25, 30], @@ -500,7 +607,7 @@ def test_sort_null_values(self) -> None: data_with_nan["age"][1] = None manager_with_nan = DefaultTableManager(data_with_nan) sorted_data = manager_with_nan.sort_values( - by="age", descending=False + [SortArgs(by="age", descending=False)] ).data assert sorted_data["age"][-1] is None @@ -508,7 +615,7 @@ def test_sort_null_values(self) -> None: # ascending sorted_data = manager_with_nan.sort_values( - by="age", descending=True + [SortArgs(by="age", descending=False)] ).data assert sorted_data["age"][-1] is None assert sorted_data["name"][-1] == "Bob" @@ -518,16 +625,62 @@ def test_sort_null_values(self) -> None: data_with_strings["name"][1] = None manager_with_strings = DefaultTableManager(data_with_strings) sorted_data = manager_with_strings.sort_values( - by="name", descending=False + by=[SortArgs(by="name", descending=False)] ).data assert sorted_data["name"][-1] is None # strings descending sorted_data = manager_with_strings.sort_values( - by="name", descending=True + by=[SortArgs(by="name", descending=False)] ).data assert sorted_data["name"][-1] is None + def test_multi_column_sort_columnar_integers_then_strings(self) -> None: + """Test multi-column sorting with columnar data - integers then strings.""" + data = { + "category": [2, 1, 1], + "name": ["Alice", "Charlie", "Bob"], + } + manager = DefaultTableManager(data) + + sorted_data = manager.sort_values( + by=[ + SortArgs(by="category", descending=False), + SortArgs(by="name", descending=False), + ] + ).data + expected_data = { + "category": [1, 1, 2], + "name": ["Bob", "Charlie", "Alice"], + } + assert sorted_data == expected_data + + def test_multi_column_sort_columnar_with_none_values(self) -> None: + """Test multi-column sorting with columnar data containing None values.""" + data = { + "group": [1, 1, 2], + "value": [None, 10, 5], + } + manager = DefaultTableManager(data) + + sorted_data = manager.sort_values( + by=[ + SortArgs(by="group", descending=False), + SortArgs(by="value", descending=False), + ] + ).data + expected_data = { + "group": [1, 1, 2], + "value": [10, None, 5], + } + assert sorted_data == expected_data + + def test_multi_column_sort_empty_list_columnar(self) -> None: + """Test that empty sort parameters return original data for columnar.""" + manager = DefaultTableManager(self.data) + sorted_data = manager.sort_values(by=[]).data + assert sorted_data == self.data + @pytest.mark.skipif( not HAS_DEPS, reason="optional dependencies not installed" ) @@ -823,7 +976,9 @@ def test_take_out_of_bounds(self) -> None: assert limited_manager.data == [] def test_sort(self) -> None: - sorted_manager = self.manager.sort_values(by="value", descending=True) + sorted_manager = self.manager.sort_values( + [SortArgs(by="value", descending=True)] + ) expected_data = [{"key": "b", "value": 2}, {"key": "a", "value": 1}] assert sorted_manager.data == expected_data @@ -832,7 +987,7 @@ def test_sort_null_values(self) -> None: data["b"] = None manager_with_nan = DefaultTableManager(data) sorted_data = manager_with_nan.sort_values( - by="value", descending=False + [SortArgs(by="value", descending=False)] ).data assert sorted_data == [ {"key": "a", "value": 1}, @@ -841,7 +996,7 @@ def test_sort_null_values(self) -> None: # descending sorted_data = manager_with_nan.sort_values( - by="value", descending=True + [SortArgs(by="value", descending=False)] ).data assert sorted_data == [ {"key": "a", "value": 1}, @@ -853,7 +1008,7 @@ def test_sort_null_values(self) -> None: {"a": "foo", "b": None, "c": "bar"} ) sorted_data = data_with_strings.sort_values( - by="value", descending=False + [SortArgs(by="value", descending=False)] ).data assert sorted_data == [ {"key": "c", "value": "bar"}, @@ -863,7 +1018,7 @@ def test_sort_null_values(self) -> None: # strings descending sorted_data = data_with_strings.sort_values( - by="value", descending=True + [SortArgs(by="value", descending=True)] ).data assert sorted_data == [ {"key": "a", "value": "foo"}, diff --git a/tests/_plugins/ui/_impl/tables/test_ibis_table.py b/tests/_plugins/ui/_impl/tables/test_ibis_table.py index 4f0e697522d..96accb5df60 100644 --- a/tests/_plugins/ui/_impl/tables/test_ibis_table.py +++ b/tests/_plugins/ui/_impl/tables/test_ibis_table.py @@ -9,6 +9,7 @@ from marimo._data.models import BinValue, ColumnStats from marimo._dependencies.dependencies import DependencyManager +from marimo._plugins.ui._impl.table import SortArgs from marimo._plugins.ui._impl.tables.format import FormatMapping from marimo._plugins.ui._impl.tables.ibis_table import ( IbisTableManagerFactory, @@ -227,7 +228,9 @@ def test_stats_string(self) -> None: ) def test_sort_values(self) -> None: - sorted_manager = self.manager.sort_values("A", descending=True) + sorted_manager = self.manager.sort_values( + [SortArgs(by="A", descending=True)] + ) assert sorted_manager.data.collect().to_dict(as_series=False) == { "A": [3, 2, 1], "B": ["c", "b", "a"], @@ -321,7 +324,9 @@ def test_sort_values_with_nulls(self) -> None: manager = self.factory.create()(table) # Descending true - sorted_manager = manager.sort_values("A", descending=True) + sorted_manager = manager.sort_values( + [SortArgs(by="A", descending=True)] + ) sorted_data = sorted_manager.data.collect().to_dict(as_series=False)[ "A" ] @@ -333,7 +338,9 @@ def test_sort_values_with_nulls(self) -> None: ] # Descending false - sorted_manager = manager.sort_values("A", descending=False) + sorted_manager = manager.sort_values( + [SortArgs(by="A", descending=False)] + ) sorted_data = sorted_manager.data.collect().to_dict(as_series=False)[ "A" ] diff --git a/tests/_plugins/ui/_impl/tables/test_narwhals.py b/tests/_plugins/ui/_impl/tables/test_narwhals.py index 6b9c91e32ae..8e4ee80f631 100644 --- a/tests/_plugins/ui/_impl/tables/test_narwhals.py +++ b/tests/_plugins/ui/_impl/tables/test_narwhals.py @@ -12,6 +12,7 @@ from marimo._data.models import BinValue, ColumnStats from marimo._dependencies.dependencies import DependencyManager +from marimo._plugins.ui._impl.table import SortArgs from marimo._plugins.ui._impl.tables.format import FormatMapping from marimo._plugins.ui._impl.tables.narwhals_table import ( NarwhalsTableManager, @@ -481,7 +482,9 @@ def test_get_stats_unwraps_scalars_properly(self) -> None: assert isinstance(bool_stats.false, int) def test_sort_values(self) -> None: - sorted_df = self.manager.sort_values("A", descending=True).data + sorted_df = self.manager.sort_values( + [SortArgs(by="A", descending=True)] + ).data expected_df = self.data.sort("A", descending=True) assert_frame_equal(sorted_df, expected_df) @@ -1231,14 +1234,14 @@ def test_search_with_regex(df: Any) -> None: def test_sort_values_with_nulls(df: Any) -> None: manager = NarwhalsTableManager.from_dataframe(df) sorted_manager: NarwhalsTableManager[Any] = manager.sort_values( - "A", descending=True + [SortArgs(by="A", descending=True)] ) assert sorted_manager.as_frame()["A"].head(3).to_list() == [3, 2, 1] last = unwrap_py_scalar(sorted_manager.as_frame()["A"].tail(1).item()) assert last is None or isnan(last) # ascending - sorted_manager = manager.sort_values("A", descending=False) + sorted_manager = manager.sort_values([SortArgs(by="A", descending=False)]) assert sorted_manager.as_frame()["A"].head(3).to_list() == [1, 2, 3] last = unwrap_py_scalar(sorted_manager.as_frame()["A"].tail(1).item()) assert last is None or isnan(last) diff --git a/tests/_plugins/ui/_impl/tables/test_pandas_table.py b/tests/_plugins/ui/_impl/tables/test_pandas_table.py index 1d31bc80032..18e9f330921 100644 --- a/tests/_plugins/ui/_impl/tables/test_pandas_table.py +++ b/tests/_plugins/ui/_impl/tables/test_pandas_table.py @@ -12,6 +12,7 @@ from marimo._data.models import ColumnStats from marimo._dependencies.dependencies import DependencyManager +from marimo._plugins.ui._impl.table import SortArgs from marimo._plugins.ui._impl.tables.format import FormatMapping from marimo._plugins.ui._impl.tables.pandas_table import ( PandasTableManagerFactory, @@ -670,7 +671,9 @@ def test_summary_does_fail_on_each_column(self) -> None: assert complex_data.get_stats(column) is not None def test_sort_values(self) -> None: - sorted_df = self.manager.sort_values("A", descending=True).data + sorted_df = self.manager.sort_values( + [SortArgs(by="A", descending=True)] + ).data expected_df = self.data.sort_values("A", ascending=False) assert_frame_equal(sorted_df, expected_df) @@ -683,7 +686,9 @@ def test_sort_values_with_index(self) -> None: ) data.index.name = "index" manager = self.factory.create()(data) - sorted_df = manager.sort_values("A", descending=True).data + sorted_df = manager.sort_values( + [SortArgs(by="A", descending=True)] + ).data assert sorted_df.to_native().index.tolist() == [3, 2, 1] def test_get_unique_column_values(self) -> None: @@ -1047,7 +1052,9 @@ def test_search_with_regex(self) -> None: def test_sort_values_with_nulls(self) -> None: df = pd.DataFrame({"A": [3, 1, None, 2]}) manager = self.factory.create()(df) - sorted_manager = manager.sort_values("A", descending=True) + sorted_manager = manager.sort_values( + [SortArgs(by="A", descending=True)] + ) assert sorted_manager.data["A"].to_list()[:-1] == [ 3.0, 2.0, @@ -1057,7 +1064,9 @@ def test_sort_values_with_nulls(self) -> None: assert last is None or isnan(last) # ascending - sorted_manager = manager.sort_values("A", descending=False) + sorted_manager = manager.sort_values( + [SortArgs(by="A", descending=False)] + ) assert sorted_manager.data["A"].to_list()[:-1] == [ 1.0, 2.0, diff --git a/tests/_plugins/ui/_impl/tables/test_polars_table.py b/tests/_plugins/ui/_impl/tables/test_polars_table.py index 7513c35fb36..b357f04108d 100644 --- a/tests/_plugins/ui/_impl/tables/test_polars_table.py +++ b/tests/_plugins/ui/_impl/tables/test_polars_table.py @@ -11,6 +11,7 @@ from marimo._data.models import ColumnStats from marimo._dependencies.dependencies import DependencyManager +from marimo._plugins.ui._impl.table import SortArgs from marimo._plugins.ui._impl.tables.format import FormatMapping from marimo._plugins.ui._impl.tables.polars_table import ( PolarsTableManagerFactory, @@ -458,7 +459,9 @@ def test_stats_does_fail_on_each_column(self) -> None: assert complex_data.get_stats(column) is not None def test_sort_values(self) -> None: - sorted_df = self.manager.sort_values("A", descending=True).data + sorted_df = self.manager.sort_values( + [SortArgs(by="A", descending=True)] + ).data expected_df = self.data.sort("A", descending=True) assert assert_frame_equal(sorted_df, expected_df) @@ -831,7 +834,9 @@ def test_sort_values_with_nulls(self) -> None: df = pl.DataFrame({"A": [3, 1, None, 2]}) manager = self.factory.create()(df) - sorted_manager = manager.sort_values("A", descending=True) + sorted_manager = manager.sort_values( + [SortArgs(by="A", descending=True)] + ) assert sorted_manager.data["A"].to_list()[:-1] == [ 3.0, 2.0, @@ -841,7 +846,9 @@ def test_sort_values_with_nulls(self) -> None: assert last is None or isnan(last) # ascending - sorted_manager = manager.sort_values("A", descending=False) + sorted_manager = manager.sort_values( + [SortArgs(by="A", descending=False)] + ) assert sorted_manager.data["A"].to_list()[:-1] == [ 1.0, 2.0, diff --git a/tests/_plugins/ui/_impl/tables/test_selection.py b/tests/_plugins/ui/_impl/tables/test_selection.py index 8f5c24a2eeb..dece9f42384 100644 --- a/tests/_plugins/ui/_impl/tables/test_selection.py +++ b/tests/_plugins/ui/_impl/tables/test_selection.py @@ -6,6 +6,7 @@ import pytest from marimo._dependencies.dependencies import DependencyManager +from marimo._plugins.ui._impl.table import SortArgs from marimo._plugins.ui._impl.tables.narwhals_table import NarwhalsTableManager from marimo._plugins.ui._impl.tables.selection import ( INDEX_COLUMN_NAME, @@ -68,7 +69,7 @@ def test_selection_with_index_column_and_sort(backend: Any): manager = NarwhalsTableManager(data) # Sort and select - sorted_data = manager.sort_values(by="age", descending=True) + sorted_data = manager.sort_values([SortArgs(by="age", descending=True)]) selected = sorted_data.select_rows([0, 2]) result = selected.data.to_dict(as_series=False) assert result[INDEX_COLUMN_NAME] == [2, 0] # Original indices preserved diff --git a/tests/_plugins/ui/_impl/test_table.py b/tests/_plugins/ui/_impl/test_table.py index 27f169a90a4..31b1ca4eff1 100644 --- a/tests/_plugins/ui/_impl/test_table.py +++ b/tests/_plugins/ui/_impl/test_table.py @@ -129,7 +129,9 @@ def test_normalize_data(executing_kernel: Kernel) -> None: def test_sort_1d_list_of_strings(dtm: DefaultTableManager) -> None: data = ["banana", "apple", "cherry", "date", "elderberry"] dtm.data = _normalize_data(data) - sorted_data = dtm.sort_values(by="value", descending=False).data + sorted_data = dtm.sort_values( + [SortArgs(by="value", descending=False)] + ).data expected_data = [ {"value": "apple"}, {"value": "banana"}, @@ -143,7 +145,9 @@ def test_sort_1d_list_of_strings(dtm: DefaultTableManager) -> None: def test_sort_1d_list_of_integers(dtm: DefaultTableManager) -> None: data = [42, 17, 23, 99, 8] dtm.data = _normalize_data(data) - sorted_data = dtm.sort_values(by="value", descending=False).data + sorted_data = dtm.sort_values( + [SortArgs(by="value", descending=False)] + ).data expected_data = [ {"value": 8}, {"value": 17}, @@ -163,10 +167,12 @@ def test_sort_list_of_dicts(dtm: DefaultTableManager) -> None: {"name": "Eve", "age": 22, "birth_year": date(2002, 1, 30)}, ] dtm.data = _normalize_data(data) - sorted_data = dtm.sort_values(by="age", descending=True).data + sorted_data = dtm.sort_values([SortArgs(by="age", descending=True)]).data with pytest.raises(KeyError): - _res = dtm.sort_values(by="missing_column", descending=True).data + _res = dtm.sort_values( + [SortArgs(by="missing_column", descending=True)] + ).data expected_data = [ {"name": "Charlie", "age": 35, "birth_year": date(1989, 12, 1)}, @@ -191,10 +197,14 @@ def test_sort_dict_of_lists(dtm: DefaultTableManager) -> None: "net_worth": [1000, 2000, 1500, 1800, 1700], } dtm.data = _normalize_data(data) - sorted_data = dtm.sort_values(by="net_worth", descending=False).data + sorted_data = dtm.sort_values( + [SortArgs(by="net_worth", descending=False)] + ).data with pytest.raises(KeyError): - _res = dtm.sort_values(by="missing_column", descending=True).data + _res = dtm.sort_values( + [SortArgs(by="missing_column", descending=True)] + ).data expected_data = { "company": [ @@ -219,10 +229,12 @@ def test_sort_dict_of_tuples(dtm: DefaultTableManager) -> None: "key5": (7, 9, 11), } dtm.data = _normalize_data(data) - sorted_data = dtm.sort_values(by="key1", descending=True).data + sorted_data = dtm.sort_values([SortArgs(by="key1", descending=True)]).data with pytest.raises(KeyError): - _res = dtm.sort_values(by="missing_column", descending=True).data + _res = dtm.sort_values( + [SortArgs(by="missing_column", descending=True)] + ).data expected_data = [ {"key1": 42, "key2": 99, "key3": 34, "key4": 1, "key5": 7}, @@ -293,7 +305,7 @@ def test_value_with_sorting_then_selection() -> None: table._search( SearchTableArgs( - sort=SortArgs("value", descending=True), + sort=[SortArgs(by="value", descending=True)], page_size=10, page_number=0, ) @@ -304,10 +316,7 @@ def test_value_with_sorting_then_selection() -> None: table._search( SearchTableArgs( - sort=SortArgs( - "value", - descending=False, - ), + sort=[SortArgs(by="value", descending=False)], page_size=10, page_number=0, ) @@ -330,7 +339,7 @@ def test_value_with_sorting_then_selection_dfs(df: Any) -> None: table = ui.table(df) table._search( SearchTableArgs( - sort=SortArgs("a", descending=True), + sort=[SortArgs(by="a", descending=True)], page_size=10, page_number=0, ) @@ -341,7 +350,7 @@ def test_value_with_sorting_then_selection_dfs(df: Any) -> None: table._search( SearchTableArgs( - sort=SortArgs("a", descending=False), + sort=[SortArgs(by="a", descending=False)], page_size=10, page_number=0, ) @@ -516,7 +525,7 @@ def test_value_with_selection_then_sorting_dict_of_lists() -> None: table._search( SearchTableArgs( - sort=SortArgs("net_worth", descending=True), + sort=[SortArgs(by="net_worth", descending=True)], page_size=10, page_number=0, ) @@ -559,7 +568,7 @@ def test_value_with_cell_selection_then_sorting_dict_of_lists() -> None: table._search( SearchTableArgs( - sort=SortArgs("net_worth", descending=True), + sort=[SortArgs(by="net_worth", descending=True)], page_size=10, page_number=0, ) @@ -589,7 +598,7 @@ def test_search_sort_nonexistent_columns() -> None: # no error raised table._search( SearchTableArgs( - sort=SortArgs("missing_column", descending=False), + sort=[SortArgs(by="missing_column", descending=False)], page_size=10, page_number=0, ) @@ -1727,7 +1736,7 @@ def always_green(_row, _col, _value): page_size=2, page_number=0, query="", - sort=SortArgs(by="column_0", descending=True), + sort=[SortArgs(by="column_0", descending=True)], ) ) # Sorted rows have reverse order of row_ids @@ -2010,7 +2019,7 @@ def test_max_columns_not_provided_with_sort(): search_args = SearchTableArgs( page_size=10, page_number=0, - sort=SortArgs(by="col0", descending=True), + sort=[SortArgs(by="col0", descending=True)], max_columns=MAX_COLUMNS_NOT_PROVIDED, ) response = table._search(search_args) @@ -2021,7 +2030,7 @@ def test_max_columns_not_provided_with_sort(): search_args = SearchTableArgs( page_size=10, page_number=0, - sort=SortArgs(by="col0", descending=True), + sort=[SortArgs(by="col0", descending=True)], max_columns=20, ) response = table._search(search_args) @@ -2032,7 +2041,7 @@ def test_max_columns_not_provided_with_sort(): search_args = SearchTableArgs( page_size=10, page_number=0, - sort=SortArgs(by="col0", descending=True), + sort=[SortArgs(by="col0", descending=True)], max_columns=None, ) response = table._search(search_args)