diff --git a/dask_sql/physical/utils/sort.py b/dask_sql/physical/utils/sort.py index 2305ef1b6..46b72b6c8 100644 --- a/dask_sql/physical/utils/sort.py +++ b/dask_sql/physical/utils/sort.py @@ -25,7 +25,7 @@ def apply_sort( by=sort_columns, ascending=sort_ascending, na_position="first" if sort_null_first[0] else "last", - ) + ).persist() # dask-cudf only supports ascending sort / nulls last: # https://github.com/rapidsai/cudf/pull/9250 @@ -37,34 +37,30 @@ def apply_sort( and not any(sort_null_first) ): try: - return df.sort_values(sort_columns, ignore_index=True) + return df.sort_values(sort_columns, ignore_index=True).persist() except ValueError: pass - # Split the first column. We need to handle this one with set_index - first_sort_column = sort_columns[0] - first_sort_ascending = sort_ascending[0] - first_null_first = sort_null_first[0] - - # Only sort by first column first - # As sorting is rather expensive, we bether persist here - df = df.persist() - df = _sort_first_column( - df, first_sort_column, first_sort_ascending, first_null_first - ) + # Dask doesn't natively support multi-column sorting; + # we work around this by initially sorting by the first + # column then handling the rest with `map_partitions` + df = df.sort_values( + by=sort_columns[0], + ascending=sort_ascending[0], + na_position="first" if sort_null_first[0] else "last", + ).persist() # sort the remaining columns if given if len(sort_columns) > 1: - df = df.persist() df = df.map_partitions( make_pickable_without_dask_sql(sort_partition_func), meta=df, sort_columns=sort_columns, sort_ascending=sort_ascending, sort_null_first=sort_null_first, - ) + ).persist() - return df.persist() + return df def sort_partition_func( @@ -94,42 +90,3 @@ def sort_partition_func( ) return partition - - -def _sort_first_column(df, first_sort_column, first_sort_ascending, first_null_first): - # Dask can only sort if there are no NaNs in the first column. - # Therefore we need to do a single pass over the dataframe - # to check if we have NaNs in the first column - # If this is the case, we concat the NaN values to the front - # That might be a very complex operation and should - # in general be avoided - col = df[first_sort_column] - is_na = col.isna().persist() - if is_na.any().compute(): - df_is_na = df[is_na].reset_index(drop=True).repartition(1) - df_not_is_na = ( - df[~is_na].set_index(first_sort_column, drop=False).reset_index(drop=True) - ) - else: - df_is_na = None - df_not_is_na = df.set_index(first_sort_column, drop=False).reset_index( - drop=True - ) - if not first_sort_ascending: - # As set_index().reset_index() always sorts ascending, we need to reverse - # the order inside all partitions and the order of the partitions itself - # We do not need to do this for the nan-partitions - df_not_is_na = df_not_is_na.map_partitions( - lambda partition: partition[::-1], meta=df - ) - df_not_is_na = df_not_is_na.partitions[::-1] - - if df_is_na is not None: - if first_null_first: - df = dd.concat([df_is_na, df_not_is_na]) - else: - df = dd.concat([df_not_is_na, df_is_na]) - else: - df = df_not_is_na - - return df