diff --git a/dask_sql/physical/utils/sort.py b/dask_sql/physical/utils/sort.py index 5022b5cd4..0e4cc9d85 100644 --- a/dask_sql/physical/utils/sort.py +++ b/dask_sql/physical/utils/sort.py @@ -50,33 +50,27 @@ def apply_sort( except ValueError: pass - # Dask doesn't natively support multi-column sorting; - # we work around this by initially sorting by the first - # column then handling the rest with `map_partitions` - df = df.sort_values( + # if standard `sort_values` can't handle ascending / null position params, + # we extend it using our custom sort function + return df.sort_values( by=sort_columns[0], ascending=sort_ascending[0], na_position="first" if sort_null_first[0] else "last", + sort_function=make_pickable_without_dask_sql(sort_partition_func), + sort_function_kwargs={ + "sort_columns": sort_columns, + "sort_ascending": sort_ascending, + "sort_null_first": sort_null_first, + }, ).persist() - # sort the remaining columns if given - if len(sort_columns) > 1: - df = df.map_partitions( - make_pickable_without_dask_sql(sort_partition_func), - meta=df, - sort_columns=sort_columns, - sort_ascending=sort_ascending, - sort_null_first=sort_null_first, - ).persist() - - return df - def sort_partition_func( partition: pd.DataFrame, sort_columns: List[str], sort_ascending: List[bool], sort_null_first: List[bool], + **kwargs, ): if partition.empty: return partition