Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 12 additions & 55 deletions dask_sql/physical/utils/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def apply_sort(
by=sort_columns,
ascending=sort_ascending,
na_position="first" if sort_null_first[0] else "last",
)
).persist()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noting that I added this persist call after sorting to match up with the other sort paths, which persist after sorting.


# dask-cudf only supports ascending sort / nulls last:
# https://github.com/rapidsai/cudf/pull/9250
Expand All @@ -37,34 +37,30 @@ def apply_sort(
and not any(sort_null_first)
):
try:
return df.sort_values(sort_columns, ignore_index=True)
return df.sort_values(sort_columns, ignore_index=True).persist()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

except ValueError:
pass

# Split the first column. We need to handle this one with set_index
first_sort_column = sort_columns[0]
first_sort_ascending = sort_ascending[0]
first_null_first = sort_null_first[0]

# Only sort by first column first
    # As sorting is rather expensive, we better persist here
df = df.persist()
df = _sort_first_column(
df, first_sort_column, first_sort_ascending, first_null_first
)
# Dask doesn't natively support multi-column sorting;
# we work around this by initially sorting by the first
# column then handling the rest with `map_partitions`
df = df.sort_values(
by=sort_columns[0],
ascending=sort_ascending[0],
na_position="first" if sort_null_first[0] else "last",
).persist()

# sort the remaining columns if given
if len(sort_columns) > 1:
df = df.persist()
df = df.map_partitions(
make_pickable_without_dask_sql(sort_partition_func),
meta=df,
sort_columns=sort_columns,
sort_ascending=sort_ascending,
sort_null_first=sort_null_first,
)
).persist()

return df.persist()
return df


def sort_partition_func(
Expand Down Expand Up @@ -94,42 +90,3 @@ def sort_partition_func(
)

return partition


def _sort_first_column(df, first_sort_column, first_sort_ascending, first_null_first):
# Dask can only sort if there are no NaNs in the first column.
# Therefore we need to do a single pass over the dataframe
# to check if we have NaNs in the first column
# If this is the case, we concat the NaN values to the front
# That might be a very complex operation and should
# in general be avoided
col = df[first_sort_column]
is_na = col.isna().persist()
if is_na.any().compute():
df_is_na = df[is_na].reset_index(drop=True).repartition(1)
df_not_is_na = (
df[~is_na].set_index(first_sort_column, drop=False).reset_index(drop=True)
)
else:
df_is_na = None
df_not_is_na = df.set_index(first_sort_column, drop=False).reset_index(
drop=True
)
if not first_sort_ascending:
# As set_index().reset_index() always sorts ascending, we need to reverse
# the order inside all partitions and the order of the partitions itself
# We do not need to do this for the nan-partitions
df_not_is_na = df_not_is_na.map_partitions(
lambda partition: partition[::-1], meta=df
)
df_not_is_na = df_not_is_na.partitions[::-1]

if df_is_na is not None:
if first_null_first:
df = dd.concat([df_is_na, df_not_is_na])
else:
df = dd.concat([df_not_is_na, df_is_na])
else:
df = df_not_is_na

return df