42 changes: 28 additions & 14 deletions dask_sql/physical/rel/logical/aggregate.py
@@ -381,30 +381,44 @@ def _perform_aggregation(
):
tmp_df = df

# format aggregations for Dask; also check whether we can use the fast path for
# groupby, which is only supported if we are not using any custom aggregations
aggregations_dict = defaultdict(dict)
fast_groupby = True
for aggregation in aggregations:
input_col, output_col, aggregation_f = aggregation
aggregations_dict[input_col][output_col] = aggregation_f
if not isinstance(aggregation_f, str):
fast_groupby = False

# filter dataframe if specified
if filter_column:
filter_expression = tmp_df[filter_column]
tmp_df = tmp_df[filter_expression]

logger.debug(f"Filtered by {filter_column} before aggregation.")

group_columns = [tmp_df[group_column] for group_column in group_columns]
group_columns_and_nulls = get_groupby_with_nulls_cols(
tmp_df, group_columns, additional_column_name
)
grouped_df = tmp_df.groupby(by=group_columns_and_nulls)

# Convert into the correct format for dask
aggregations_dict = defaultdict(dict)
for aggregation in aggregations:
input_col, output_col, aggregation_f = aggregation
# we might need a temporary column name if no groupby columns are specified
if additional_column_name is None:
additional_column_name = new_temporary_column(df)

aggregations_dict[input_col][output_col] = aggregation_f
# perform the groupby operation; if we are using custom aggregations, we must
# handle null values manually (this is slow)
if fast_groupby:
grouped_df = tmp_df.groupby(
by=(group_columns or [additional_column_name]), dropna=False
)
else:
group_columns = [tmp_df[group_column] for group_column in group_columns]
group_columns_and_nulls = get_groupby_with_nulls_cols(
tmp_df, group_columns, additional_column_name
)
grouped_df = tmp_df.groupby(by=group_columns_and_nulls)

# Now apply the aggregation
# apply the aggregation(s)
logger.debug(f"Performing aggregation {dict(aggregations_dict)}")
agg_result = grouped_df.agg(aggregations_dict, **groupby_agg_options)

# ... fix the column names to a single level ...
# fix the column names to a single level
agg_result.columns = agg_result.columns.get_level_values(-1)

return agg_result
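
The fast/slow split above turns on how NULL group keys are preserved. A minimal sketch of the two strategies, with toy data and column names that are not from this PR (assumes pandas >= 1.1 for the dropna flag; dask.dataframe mirrors the same keyword):

import pandas as pd

df = pd.DataFrame({"k": ["x", None, "x", None], "v": [1, 2, 3, 4]})

# fast path: built-in aggregations can lean on dropna=False, which keeps
# the NULL key as its own group (matching SQL GROUP BY semantics)
fast = df.groupby("k", dropna=False)["v"].sum()

# slow path: custom aggregations group by an explicit null marker plus a
# NaN-free copy of the key, so NULL rows still land in a distinct group
slow = df.groupby([df["k"].isnull(), df["k"].fillna(0)])["v"].sum()

Both produce one group for the "x" rows and one for the NULL rows; only the shape of the resulting group index differs.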
3 changes: 1 addition & 2 deletions dask_sql/physical/utils/groupby.py
@@ -19,8 +19,7 @@ def get_groupby_with_nulls_cols(

group_columns_and_nulls = []
for group_column in group_columns:
# the ~ makes NaN come first
is_null_column = ~(group_column.isnull())
is_null_column = group_column.isnull()
non_nan_group_column = group_column.fillna(0)

# split_out doesn't work if both columns have the same name
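
For the dropped negation above: the marker's polarity only affects how the group keys sort, not which rows group together, and the compatibility tests now pin the ordering explicitly with NULLS FIRST. A tiny self-contained check (plain pandas, made-up data):

import pandas as pd

s = pd.Series([None, 1.0, None, 2.0])
vals = pd.Series(range(4))

with_negation = vals.groupby([~s.isnull(), s.fillna(0)]).size()
without_negation = vals.groupby([s.isnull(), s.fillna(0)]).size()

# same group membership either way; only the key order flips
assert sorted(with_negation.tolist()) == sorted(without_negation.tolist())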
12 changes: 9 additions & 3 deletions tests/integration/test_compatibility.py
@@ -304,7 +304,9 @@ def test_agg_count():
COUNT(DISTINCT d) AS cd_d,
COUNT(e) AS c_e,
COUNT(DISTINCT a) AS cd_e
FROM a GROUP BY a, b
FROM a GROUP BY a, b ORDER BY
a NULLS FIRST,
b NULLS FIRST
""",
a=a,
)
@@ -354,7 +356,9 @@ def test_agg_sum_avg():
AVG(e) AS avg_e,
SUM(a)+AVG(e) AS mix_1,
SUM(a+e) AS mix_2
FROM a GROUP BY a,b
FROM a GROUP BY a, b ORDER BY
a NULLS FIRST,
b NULLS FIRST
""",
a=a,
)
@@ -423,7 +427,9 @@ def test_agg_min_max():
MAX(g) AS max_g,
MIN(a+e) AS mix_1,
MIN(a)+MIN(e) AS mix_2
FROM a GROUP BY a, b
FROM a GROUP BY a, b ORDER BY
a NULLS FIRST,
b NULLS FIRST
""",
a=a,
)
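
The ORDER BY ... NULLS FIRST clauses added to these three tests make the comparisons deterministic now that the NULL group's position is no longer forced by the negated marker column. On the pandas side, the matching sort would look roughly like this (illustrative frame, not the test data):

import pandas as pd

expected = pd.DataFrame({"a": [1.0, None, 2.0], "b": [None, 3.0, 4.0]})
# NULLS FIRST in SQL corresponds to na_position="first" in pandas
expected = expected.sort_values(["a", "b"], na_position="first")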
23 changes: 18 additions & 5 deletions tests/integration/test_groupby.py
@@ -375,16 +375,27 @@ def test_groupby_split_out(c, input_table, split_out, request):
c.drop_config("dask.groupby.aggregate.split_out")


@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
@pytest.mark.parametrize("split_every,expected_keys", [(2, 154), (3, 148), (4, 144)])
@pytest.mark.parametrize(
"gpu,split_every,expected_keys",
[
(False, 2, 74),
(False, 3, 68),
(False, 4, 64),
pytest.param(True, 2, 91, marks=pytest.mark.gpu),
pytest.param(True, 3, 85, marks=pytest.mark.gpu),
pytest.param(True, 4, 81, marks=pytest.mark.gpu),
],
)
def test_groupby_split_every(c, gpu, split_every, expected_keys):
xd = pytest.importorskip("cudf") if gpu else pd
input_ddf = dd.from_pandas(
xd.DataFrame({"user_id": [1, 2, 3, 4] * 16, "b": [5, 6, 7, 8] * 16}),
npartitions=16,
) # Need an input with multiple partitions to demonstrate split_every

c.create_table("split_every_input", input_ddf)
c.set_config(("dask.groupby.aggregate.split_every", split_every))

df = c.sql(
"""
SELECT
@@ -397,10 +408,12 @@ def test_groupby_split_every(c, gpu, split_every, expected_keys):
input_ddf.groupby(by="user_id")
.agg({"b": "sum"}, split_every=split_every)
.reset_index(drop=False)
.rename(columns={"b": "S"})
.sort_values("user_id")
)
expected_df = expected_df.rename(columns={"b": "S"})
expected_df = expected_df.sort_values("user_id")

assert len(df.dask.keys()) == expected_keys
dd.assert_eq(df.compute().sort_values("user_id"), expected_df, check_index=False)
dd.assert_eq(df, expected_df, check_index=False)

c.drop_config("dask.groupby.aggregate.split_every")
c.drop_table("split_every_input")
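
The updated expected_keys values are simply the CPU/GPU task-graph sizes observed for each split_every. The direction of the relationship can be sanity-checked outside dask-sql; a rough sketch on a 16-partition toy input mirroring the test (exact counts vary by dask version, so only the ordering is asserted):

import pandas as pd
import dask.dataframe as dd

ddf = dd.from_pandas(
    pd.DataFrame({"user_id": [1, 2, 3, 4] * 16, "b": [5, 6, 7, 8] * 16}),
    npartitions=16,
)

# a smaller split_every means more tree-reduction levels, hence more task keys
narrow = ddf.groupby("user_id").agg({"b": "sum"}, split_every=2)
wide = ddf.groupby("user_id").agg({"b": "sum"}, split_every=4)

assert len(narrow.__dask_graph__()) > len(wide.__dask_graph__())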