Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 7 additions & 24 deletions dask_sql/physical/utils/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,13 @@ def apply_sort(
sort_null_first: List[bool],
) -> dd.DataFrame:
# if we have a single partition, we can sometimes sort with map_partitions
if df.npartitions == 1:
if dask_cudf is not None and isinstance(df, dask_cudf.DataFrame):
# cudf only supports null positioning if `ascending` is a single boolean:
# https://github.com/rapidsai/cudf/issues/9400
if (all(sort_ascending) or not any(sort_ascending)) and not any(
sort_null_first[1:]
):
return df.map_partitions(
M.sort_values,
by=sort_columns,
ascending=all(sort_ascending),
na_position="first" if sort_null_first[0] else "last",
)
if not any(sort_null_first):
return df.map_partitions(
M.sort_values, by=sort_columns, ascending=sort_ascending
)
elif not any(sort_null_first[1:]):
return df.map_partitions(
M.sort_values,
by=sort_columns,
ascending=sort_ascending,
na_position="first" if sort_null_first[0] else "last",
)
if df.npartitions == 1 and (all(sort_null_first) or not any(sort_null_first)):
return df.map_partitions(
M.sort_values,
by=sort_columns,
ascending=sort_ascending,
na_position="first" if sort_null_first[0] else "last",
)

# dask-cudf only supports ascending sort / nulls last:
# https://github.com/rapidsai/cudf/pull/9250
Expand Down
55 changes: 41 additions & 14 deletions tests/integration/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,50 @@ def test_sort_with_nan_more_columns(gpu):
)
c.create_table("df", df)

df_result = (
c.sql(
"SELECT * FROM df ORDER BY a ASC NULLS FIRST, b DESC NULLS LAST, c ASC NULLS FIRST"
)
.c.compute()
.reset_index(drop=True)
df_result = c.sql(
"SELECT * FROM df ORDER BY a ASC NULLS FIRST, b DESC NULLS LAST, c ASC NULLS FIRST"
)
dd.assert_eq(
df_result,
xd.DataFrame(
{
"a": [float("nan"), float("nan"), 1, 1, 2, 2],
"b": [float("inf"), 5, 1, 1, 2, float("nan")],
"c": [5, 6, float("nan"), 1, 3, 4],
}
),
check_index=False,
)
dd.assert_eq(df_result, xd.Series([5, 6, float("nan"), 1, 3, 4]), check_names=False)

df_result = (
c.sql(
"SELECT * FROM df ORDER BY a ASC NULLS LAST, b DESC NULLS FIRST, c DESC NULLS LAST"
)
.c.compute()
.reset_index(drop=True)
df_result = c.sql(
"SELECT * FROM df ORDER BY a ASC NULLS LAST, b DESC NULLS FIRST, c DESC NULLS LAST"
)
dd.assert_eq(
df_result,
xd.DataFrame(
{
"a": [1, 1, 2, 2, float("nan"), float("nan")],
"b": [1, 1, float("nan"), 2, float("inf"), 5],
"c": [1, float("nan"), 4, 3, 5, 6],
}
),
check_index=False,
)

df_result = c.sql(
"SELECT * FROM df ORDER BY a ASC NULLS FIRST, b DESC NULLS LAST, c DESC NULLS LAST"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old logic that we were using to decide if we could do single-partition sorts (not any(sort_null_first[1:])) was incorrect, but we weren't catching that in this test because we weren't checking the full dataframe's sorting. This new test should fail with the old logic but pass with this PR.

)
dd.assert_eq(
df_result,
xd.DataFrame(
{
"a": [float("nan"), float("nan"), 1, 1, 2, 2],
"b": [float("inf"), 5, 1, 1, 2, float("nan")],
"c": [5, 6, 1, float("nan"), 3, 4],
}
),
check_index=False,
)
dd.assert_eq(df_result, xd.Series([1, float("nan"), 4, 3, 5, 6]), check_names=False)


@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
Expand Down