diff --git a/dask_sql/physical/utils/sort.py b/dask_sql/physical/utils/sort.py index 5857a3321..2305ef1b6 100644 --- a/dask_sql/physical/utils/sort.py +++ b/dask_sql/physical/utils/sort.py @@ -19,30 +19,13 @@ def apply_sort( sort_null_first: List[bool], ) -> dd.DataFrame: # if we have a single partition, we can sometimes sort with map_partitions - if df.npartitions == 1: - if dask_cudf is not None and isinstance(df, dask_cudf.DataFrame): - # cudf only supports null positioning if `ascending` is a single boolean: - # https://github.com/rapidsai/cudf/issues/9400 - if (all(sort_ascending) or not any(sort_ascending)) and not any( - sort_null_first[1:] - ): - return df.map_partitions( - M.sort_values, - by=sort_columns, - ascending=all(sort_ascending), - na_position="first" if sort_null_first[0] else "last", - ) - if not any(sort_null_first): - return df.map_partitions( - M.sort_values, by=sort_columns, ascending=sort_ascending - ) - elif not any(sort_null_first[1:]): - return df.map_partitions( - M.sort_values, - by=sort_columns, - ascending=sort_ascending, - na_position="first" if sort_null_first[0] else "last", - ) + if df.npartitions == 1 and (all(sort_null_first) or not any(sort_null_first)): + return df.map_partitions( + M.sort_values, + by=sort_columns, + ascending=sort_ascending, + na_position="first" if sort_null_first[0] else "last", + ) # dask-cudf only supports ascending sort / nulls last: # https://github.com/rapidsai/cudf/pull/9250 diff --git a/tests/integration/test_sort.py b/tests/integration/test_sort.py index 07cf33609..6d03220f6 100644 --- a/tests/integration/test_sort.py +++ b/tests/integration/test_sort.py @@ -216,23 +216,50 @@ def test_sort_with_nan_more_columns(gpu): ) c.create_table("df", df) - df_result = ( - c.sql( - "SELECT * FROM df ORDER BY a ASC NULLS FIRST, b DESC NULLS LAST, c ASC NULLS FIRST" - ) - .c.compute() - .reset_index(drop=True) + df_result = c.sql( + "SELECT * FROM df ORDER BY a ASC NULLS FIRST, b DESC NULLS LAST, c ASC NULLS FIRST" + ) + dd.assert_eq( + df_result, + xd.DataFrame( + { + "a": [float("nan"), float("nan"), 1, 1, 2, 2], + "b": [float("inf"), 5, 1, 1, 2, float("nan")], + "c": [5, 6, float("nan"), 1, 3, 4], + } + ), + check_index=False, ) - dd.assert_eq(df_result, xd.Series([5, 6, float("nan"), 1, 3, 4]), check_names=False) - df_result = ( - c.sql( - "SELECT * FROM df ORDER BY a ASC NULLS LAST, b DESC NULLS FIRST, c DESC NULLS LAST" - ) - .c.compute() - .reset_index(drop=True) + df_result = c.sql( + "SELECT * FROM df ORDER BY a ASC NULLS LAST, b DESC NULLS FIRST, c DESC NULLS LAST" + ) + dd.assert_eq( + df_result, + xd.DataFrame( + { + "a": [1, 1, 2, 2, float("nan"), float("nan")], + "b": [1, 1, float("nan"), 2, float("inf"), 5], + "c": [1, float("nan"), 4, 3, 5, 6], + } + ), + check_index=False, + ) + + df_result = c.sql( + "SELECT * FROM df ORDER BY a ASC NULLS FIRST, b DESC NULLS LAST, c DESC NULLS LAST" + ) + dd.assert_eq( + df_result, + xd.DataFrame( + { + "a": [float("nan"), float("nan"), 1, 1, 2, 2], + "b": [float("inf"), 5, 1, 1, 2, float("nan")], + "c": [5, 6, 1, float("nan"), 3, 4], + } + ), + check_index=False, ) - dd.assert_eq(df_result, xd.Series([1, float("nan"), 4, 3, 5, 6]), check_names=False) @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])