diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py
index 83c19fca0..cfdaca82e 100644
--- a/dask_sql/physical/rel/logical/aggregate.py
+++ b/dask_sql/physical/rel/logical/aggregate.py
@@ -381,30 +381,44 @@ def _perform_aggregation(
     ):
         tmp_df = df
 
+        # format aggregations for Dask; also check if we can use fast path for
+        # groupby, which is only supported if we are not using any custom aggregations
+        aggregations_dict = defaultdict(dict)
+        fast_groupby = True
+        for aggregation in aggregations:
+            input_col, output_col, aggregation_f = aggregation
+            aggregations_dict[input_col][output_col] = aggregation_f
+            if not isinstance(aggregation_f, str):
+                fast_groupby = False
+
+        # filter dataframe if specified
         if filter_column:
             filter_expression = tmp_df[filter_column]
             tmp_df = tmp_df[filter_expression]
-
             logger.debug(f"Filtered by {filter_column} before aggregation.")
 
-        group_columns = [tmp_df[group_column] for group_column in group_columns]
-        group_columns_and_nulls = get_groupby_with_nulls_cols(
-            tmp_df, group_columns, additional_column_name
-        )
-        grouped_df = tmp_df.groupby(by=group_columns_and_nulls)
-
-        # Convert into the correct format for dask
-        aggregations_dict = defaultdict(dict)
-        for aggregation in aggregations:
-            input_col, output_col, aggregation_f = aggregation
+        # we might need a temporary column name if no groupby columns are specified
+        if additional_column_name is None:
+            additional_column_name = new_temporary_column(df)
 
-            aggregations_dict[input_col][output_col] = aggregation_f
+        # perform groupby operation; if we are using custom aggregations, we must handle
+        # null values manually (this is slow)
+        if fast_groupby:
+            grouped_df = tmp_df.groupby(
+                by=(group_columns or [additional_column_name]), dropna=False
+            )
+        else:
+            group_columns = [tmp_df[group_column] for group_column in group_columns]
+            group_columns_and_nulls = get_groupby_with_nulls_cols(
+                tmp_df, group_columns, additional_column_name
+            )
+            grouped_df = tmp_df.groupby(by=group_columns_and_nulls)
 
-        # Now apply the aggregation
+        # apply the aggregation(s)
         logger.debug(f"Performing aggregation {dict(aggregations_dict)}")
         agg_result = grouped_df.agg(aggregations_dict, **groupby_agg_options)
 
-        # ... fix the column names to a single level ...
+        # fix the column names to a single level
         agg_result.columns = agg_result.columns.get_level_values(-1)
 
         return agg_result
diff --git a/dask_sql/physical/utils/groupby.py b/dask_sql/physical/utils/groupby.py
index 33479bd20..97070bdd0 100644
--- a/dask_sql/physical/utils/groupby.py
+++ b/dask_sql/physical/utils/groupby.py
@@ -19,8 +19,7 @@ def get_groupby_with_nulls_cols(
 
     group_columns_and_nulls = []
     for group_column in group_columns:
-        # the ~ makes NaN come first
-        is_null_column = ~(group_column.isnull())
+        is_null_column = group_column.isnull()
         non_nan_group_column = group_column.fillna(0)
 
         # split_out doesn't work if both columns have the same name
diff --git a/tests/integration/test_compatibility.py b/tests/integration/test_compatibility.py
index a63974c95..d5b785a55 100644
--- a/tests/integration/test_compatibility.py
+++ b/tests/integration/test_compatibility.py
@@ -304,7 +304,9 @@ def test_agg_count():
         COUNT(DISTINCT d) AS cd_d,
         COUNT(e) AS c_e,
         COUNT(DISTINCT a) AS cd_e
-    FROM a GROUP BY a, b
+    FROM a GROUP BY a, b ORDER BY
+        a NULLS FIRST,
+        b NULLS FIRST
     """,
         a=a,
     )
@@ -354,7 +356,9 @@ def test_agg_sum_avg():
         AVG(e) AS avg_e,
         SUM(a)+AVG(e) AS mix_1,
         SUM(a+e) AS mix_2
-    FROM a GROUP BY a,b
+    FROM a GROUP BY a, b ORDER BY
+        a NULLS FIRST,
+        b NULLS FIRST
     """,
         a=a,
     )
@@ -423,7 +427,9 @@ def test_agg_min_max():
         MAX(g) AS max_g,
         MIN(a+e) AS mix_1,
         MIN(a)+MIN(e) AS mix_2
-    FROM a GROUP BY a, b
+    FROM a GROUP BY a, b ORDER BY
+        a NULLS FIRST,
+        b NULLS FIRST
     """,
         a=a,
     )
diff --git a/tests/integration/test_groupby.py b/tests/integration/test_groupby.py
index 87efa84f9..1281ff027 100644
--- a/tests/integration/test_groupby.py
+++ b/tests/integration/test_groupby.py
@@ -375,16 +375,27 @@ def test_groupby_split_out(c, input_table, split_out, request):
     c.drop_config("dask.groupby.aggregate.split_out")
 
 
-@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
-@pytest.mark.parametrize("split_every,expected_keys", [(2, 154), (3, 148), (4, 144)])
+@pytest.mark.parametrize(
+    "gpu,split_every,expected_keys",
+    [
+        (False, 2, 74),
+        (False, 3, 68),
+        (False, 4, 64),
+        pytest.param(True, 2, 91, marks=pytest.mark.gpu),
+        pytest.param(True, 3, 85, marks=pytest.mark.gpu),
+        pytest.param(True, 4, 81, marks=pytest.mark.gpu),
+    ],
+)
 def test_groupby_split_every(c, gpu, split_every, expected_keys):
     xd = pytest.importorskip("cudf") if gpu else pd
     input_ddf = dd.from_pandas(
         xd.DataFrame({"user_id": [1, 2, 3, 4] * 16, "b": [5, 6, 7, 8] * 16}),
         npartitions=16,
     )  # Need an input with multiple partitions to demonstrate split_every
+
     c.create_table("split_every_input", input_ddf)
     c.set_config(("dask.groupby.aggregate.split_every", split_every))
+
     df = c.sql(
         """
         SELECT
@@ -397,10 +408,12 @@ def test_groupby_split_every(c, gpu, split_every, expected_keys):
         input_ddf.groupby(by="user_id")
         .agg({"b": "sum"}, split_every=split_every)
         .reset_index(drop=False)
+        .rename(columns={"b": "S"})
+        .sort_values("user_id")
     )
-    expected_df = expected_df.rename(columns={"b": "S"})
-    expected_df = expected_df.sort_values("user_id")
+
     assert len(df.dask.keys()) == expected_keys
-    dd.assert_eq(df.compute().sort_values("user_id"), expected_df, check_index=False)
+    dd.assert_eq(df, expected_df, check_index=False)
+
     c.drop_config("dask.groupby.aggregate.split_every")
     c.drop_table("split_every_input")
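For reference, the two groupby branches added in aggregate.py differ only in how null group keys are kept. The sketch below is illustrative and not part of the patch: it reproduces the idea in plain pandas, with made-up column names (user_id_is_null, user_id_filled) and toy data. The fast path simply passes dropna=False to groupby, while the slow path emulates it the way get_groupby_with_nulls_cols does, by grouping on an is-null marker plus a fillna(0) copy of the key.

```python
# Illustrative sketch only (plain pandas, not part of the patch).
import pandas as pd

df = pd.DataFrame({"user_id": [1.0, 1.0, None, None], "b": [10, 20, 30, 40]})

# fast path: dropna=False keeps null keys as their own group
fast = df.groupby("user_id", dropna=False).agg({"b": "sum"})

# slow path: group on an is-null marker plus a fillna(0) copy of the key
# (renamed here so the two grouping series do not share a name)
marker = df["user_id"].isnull().rename("user_id_is_null")
filled = df["user_id"].fillna(0).rename("user_id_filled")
slow = df.groupby([marker, filled]).agg({"b": "sum"})

print(fast)  # groups 1.0 and NaN, sums 30 and 70
print(slow)  # the same two groups, keyed by (is_null, filled) instead
```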
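The change in get_groupby_with_nulls_cols also flips where null keys land when group keys are sorted: the old ~isnull() marker made null groups sort first, while the plain isnull() marker makes them sort last, which is presumably why the compatibility tests now pin the result order with an explicit ORDER BY ... NULLS FIRST. A minimal pandas sketch of that ordering effect follows; the names a_is_null and a_filled and the sample data are made up, and it assumes the marker column is grouped ahead of the filled key column, as in the helper above.

```python
# Illustrative sketch only: how negating the is-null marker changes sort order.
import pandas as pd

s = pd.Series([2.0, None, 1.0], name="a")
df = pd.DataFrame({"a": s, "x": [10, 20, 30]})

filled = s.fillna(0).rename("a_filled")
old_marker = (~s.isnull()).rename("a_is_null")  # null rows become False -> sort first
new_marker = s.isnull().rename("a_is_null")     # null rows become True  -> sort last

print(df.groupby([old_marker, filled])["x"].sum().index.tolist())
# [(False, 0.0), (True, 1.0), (True, 2.0)]  -> the null key comes first
print(df.groupby([new_marker, filled])["x"].sum().index.tolist())
# [(False, 1.0), (False, 2.0), (True, 0.0)] -> the null key comes last
```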