Commit 0c05787

Remove null-splitting from _perform_aggregation (#326)
* Remove null-splitting from _perform_aggregation
* Sort groupby results to match Postgres output
* Add back slow path for groupby
* Don't group nulls first for slow path
* Remove unnecessary inverse comment
* Split up split_every test into CPU and GPU
* Consolidate CPU / GPU tests again
1 parent a3fc92d commit 0c05787

4 files changed: 56 additions & 24 deletions
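
For context before the per-file diffs: pandas (and therefore Dask) groupby drops rows with null group keys by default, which is why dask-sql previously split nulls out manually. A minimal standalone sketch (made-up data, not code from this commit) of the dropna=False behavior the new fast path relies on:

import pandas as pd

df = pd.DataFrame({"k": ["a", None, "a", None], "v": [1, 2, 3, 4]})

# default behavior: rows with a null group key are silently dropped
print(df.groupby("k")["v"].sum())                # only group "a" -> 4

# dropna=False keeps null keys as their own group, matching SQL semantics
print(df.groupby("k", dropna=False)["v"].sum())  # "a" -> 4, NaN -> 6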

dask_sql/physical/rel/logical/aggregate.py

Lines changed: 28 additions & 14 deletions
@@ -381,30 +381,44 @@ def _perform_aggregation(
     ):
         tmp_df = df
 
+        # format aggregations for Dask; also check if we can use the fast path for
+        # groupby, which is only supported if we are not using any custom aggregations
+        aggregations_dict = defaultdict(dict)
+        fast_groupby = True
+        for aggregation in aggregations:
+            input_col, output_col, aggregation_f = aggregation
+            aggregations_dict[input_col][output_col] = aggregation_f
+            if not isinstance(aggregation_f, str):
+                fast_groupby = False
+
+        # filter dataframe if specified
         if filter_column:
             filter_expression = tmp_df[filter_column]
             tmp_df = tmp_df[filter_expression]
-
             logger.debug(f"Filtered by {filter_column} before aggregation.")
 
-        group_columns = [tmp_df[group_column] for group_column in group_columns]
-        group_columns_and_nulls = get_groupby_with_nulls_cols(
-            tmp_df, group_columns, additional_column_name
-        )
-        grouped_df = tmp_df.groupby(by=group_columns_and_nulls)
-
-        # Convert into the correct format for dask
-        aggregations_dict = defaultdict(dict)
-        for aggregation in aggregations:
-            input_col, output_col, aggregation_f = aggregation
+        # we might need a temporary column name if no groupby columns are specified
+        if additional_column_name is None:
+            additional_column_name = new_temporary_column(df)
 
-            aggregations_dict[input_col][output_col] = aggregation_f
+        # perform groupby operation; if we are using custom aggregations, we must
+        # handle null values manually (this is slow)
+        if fast_groupby:
+            grouped_df = tmp_df.groupby(
+                by=(group_columns or [additional_column_name]), dropna=False
+            )
+        else:
+            group_columns = [tmp_df[group_column] for group_column in group_columns]
+            group_columns_and_nulls = get_groupby_with_nulls_cols(
+                tmp_df, group_columns, additional_column_name
+            )
+            grouped_df = tmp_df.groupby(by=group_columns_and_nulls)
 
-        # Now apply the aggregation
+        # apply the aggregation(s)
         logger.debug(f"Performing aggregation {dict(aggregations_dict)}")
         agg_result = grouped_df.agg(aggregations_dict, **groupby_agg_options)
 
-        # ... fix the column names to a single level ...
+        # fix the column names to a single level
         agg_result.columns = agg_result.columns.get_level_values(-1)
 
         return agg_result
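
To make the new control flow concrete, here is a small standalone sketch (with made-up aggregation triples, not code from the repo) of how the loop above builds aggregations_dict and chooses between the two paths:

from collections import defaultdict

# hypothetical (input_col, output_col, aggregation) triples; the lambda
# stands in for any custom, non-string aggregation
aggregations = [
    ("a", "sum_a", "sum"),
    ("a", "max_a", "max"),
    ("b", "first_b", lambda s: s.iloc[0]),
]

aggregations_dict = defaultdict(dict)
fast_groupby = True
for input_col, output_col, aggregation_f in aggregations:
    aggregations_dict[input_col][output_col] = aggregation_f
    if not isinstance(aggregation_f, str):
        fast_groupby = False  # any custom aggregation forces the slow path

print(dict(aggregations_dict))
# {'a': {'sum_a': 'sum', 'max_a': 'max'}, 'b': {'first_b': <function ...>}}
print(fast_groupby)  # False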

dask_sql/physical/utils/groupby.py

Lines changed: 1 addition & 2 deletions
@@ -19,8 +19,7 @@ def get_groupby_with_nulls_cols(
 
     group_columns_and_nulls = []
     for group_column in group_columns:
-        # the ~ makes NaN come first
-        is_null_column = ~(group_column.isnull())
+        is_null_column = group_column.isnull()
         non_nan_group_column = group_column.fillna(0)
 
         # split_out doesn't work if both columns have the same name
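
The slow path keeps null keys by grouping on a null-indicator column alongside the null-filled values. Below is a rough sketch of the idea (made-up data; the real helper also handles the additional column and the name clash mentioned in the comment above). With the inversion removed, null groups now sort last (False < True) rather than first, which is why the compatibility tests below pin row order explicitly:

import pandas as pd

s = pd.Series([1.0, None, 2.0, None], name="k")
df = pd.DataFrame({"v": [10, 20, 30, 40]})

# group on (is_null, value-with-nulls-filled) so null keys survive groupby
is_null_column = s.isnull().rename("k_is_null")  # renamed to avoid the clash
non_nan_group_column = s.fillna(0)

print(df.groupby([is_null_column, non_nan_group_column])["v"].sum())
# (False, 1.0) -> 10, (False, 2.0) -> 30, (True, 0.0) -> 60; nulls sort last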

tests/integration/test_compatibility.py

Lines changed: 9 additions & 3 deletions
@@ -304,7 +304,9 @@ def test_agg_count():
         COUNT(DISTINCT d) AS cd_d,
         COUNT(e) AS c_e,
         COUNT(DISTINCT a) AS cd_e
-    FROM a GROUP BY a, b
+    FROM a GROUP BY a, b ORDER BY
+        a NULLS FIRST,
+        b NULLS FIRST
     """,
         a=a,
     )
@@ -354,7 +356,9 @@ def test_agg_sum_avg():
         AVG(e) AS avg_e,
         SUM(a)+AVG(e) AS mix_1,
         SUM(a+e) AS mix_2
-    FROM a GROUP BY a,b
+    FROM a GROUP BY a, b ORDER BY
+        a NULLS FIRST,
+        b NULLS FIRST
     """,
         a=a,
     )
@@ -423,7 +427,9 @@ def test_agg_min_max():
         MAX(g) AS max_g,
         MIN(a+e) AS mix_1,
         MIN(a)+MIN(e) AS mix_2
-    FROM a GROUP BY a, b
+    FROM a GROUP BY a, b ORDER BY
+        a NULLS FIRST,
+        b NULLS FIRST
     """,
         a=a,
     )
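
Because the groupby result order is no longer guaranteed to put nulls first, these tests pin the order with ORDER BY ... NULLS FIRST to match the Postgres reference output. An illustrative pandas analogue of that ordering (not test code) is sort_values with na_position="first":

import pandas as pd

df = pd.DataFrame({"a": [2.0, None, 1.0], "b": [None, 3.0, 1.0]})

# pandas analogue of "ORDER BY a NULLS FIRST, b NULLS FIRST"
print(df.sort_values(["a", "b"], na_position="first"))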

tests/integration/test_groupby.py

Lines changed: 18 additions & 5 deletions
@@ -375,16 +375,27 @@ def test_groupby_split_out(c, input_table, split_out, request):
     c.drop_config("dask.groupby.aggregate.split_out")
 
 
-@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
-@pytest.mark.parametrize("split_every,expected_keys", [(2, 154), (3, 148), (4, 144)])
+@pytest.mark.parametrize(
+    "gpu,split_every,expected_keys",
+    [
+        (False, 2, 74),
+        (False, 3, 68),
+        (False, 4, 64),
+        pytest.param(True, 2, 91, marks=pytest.mark.gpu),
+        pytest.param(True, 3, 85, marks=pytest.mark.gpu),
+        pytest.param(True, 4, 81, marks=pytest.mark.gpu),
+    ],
+)
 def test_groupby_split_every(c, gpu, split_every, expected_keys):
     xd = pytest.importorskip("cudf") if gpu else pd
     input_ddf = dd.from_pandas(
         xd.DataFrame({"user_id": [1, 2, 3, 4] * 16, "b": [5, 6, 7, 8] * 16}),
         npartitions=16,
     )  # Need an input with multiple partitions to demonstrate split_every
+
     c.create_table("split_every_input", input_ddf)
     c.set_config(("dask.groupby.aggregate.split_every", split_every))
+
     df = c.sql(
         """
         SELECT
@@ -397,10 +408,12 @@ def test_groupby_split_every(c, gpu, split_every, expected_keys):
         input_ddf.groupby(by="user_id")
         .agg({"b": "sum"}, split_every=split_every)
         .reset_index(drop=False)
+        .rename(columns={"b": "S"})
+        .sort_values("user_id")
     )
-    expected_df = expected_df.rename(columns={"b": "S"})
-    expected_df = expected_df.sort_values("user_id")
+
     assert len(df.dask.keys()) == expected_keys
-    dd.assert_eq(df.compute().sort_values("user_id"), expected_df, check_index=False)
+    dd.assert_eq(df, expected_df, check_index=False)
+
     c.drop_config("dask.groupby.aggregate.split_every")
     c.drop_table("split_every_input")
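
The expected key counts shrink as split_every grows, since a larger fan-in means fewer tree-reduction layers in the task graph (the CPU and GPU counts differ, presumably because cuDF-backed graphs are built differently). A standalone sketch of the effect, reusing the test's input shape (the printed counts are illustrative, not the test's expected values):

import pandas as pd
import dask.dataframe as dd

pdf = pd.DataFrame({"user_id": [1, 2, 3, 4] * 16, "b": [5, 6, 7, 8] * 16})
ddf = dd.from_pandas(pdf, npartitions=16)  # multiple partitions needed

for split_every in (2, 3, 4):
    result = ddf.groupby(by="user_id").agg({"b": "sum"}, split_every=split_every)
    # smaller split_every -> more combine layers -> more task-graph keys
    print(split_every, len(result.dask.keys()))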
