diff --git a/dask_sql/context.py b/dask_sql/context.py
index 2b3565f58..a5443f019 100644
--- a/dask_sql/context.py
+++ b/dask_sql/context.py
@@ -531,6 +531,71 @@ def register_model(
         schema_name = schema_name or self.schema_name
         self.schema[schema_name].models[model_name.lower()] = (model, training_columns)

+    def set_config(
+        self,
+        config_options: Union[Tuple[str, Any], Dict[str, Any]],
+        schema_name: str = None,
+    ):
+        """
+        Add configuration options to a schema.
+        A configuration option can be used to control the behavior of certain configurable operations.
+
+        E.g. `dask.groupby.aggregate.split_out` can be used to split the output of a groupby aggregation into multiple partitions.
+
+        Args:
+            config_options (:obj:`Tuple[str, Any]` or :obj:`Dict[str, Any]`): config option and value to set
+            schema_name (:obj:`str`): Optionally select the schema to set the configs on
+
+        Example:
+            .. code-block:: python
+
+                from dask_sql import Context
+
+                c = Context()
+                c.set_config(("dask.groupby.aggregate.split_out", 1))
+                c.set_config(
+                    {
+                        "dask.groupby.aggregate.split_out": 2,
+                        "dask.groupby.aggregate.split_every": 4,
+                    }
+                )
+
+        """
+        schema_name = schema_name or self.schema_name
+        self.schema[schema_name].config.set_config(config_options)
+
+    def drop_config(
+        self, config_strs: Union[str, List[str]], schema_name: str = None,
+    ):
+        """
+        Drop user-set configuration options from a schema.
+
+        Args:
+            config_strs (:obj:`str` or :obj:`List[str]`): config key or keys to drop
+            schema_name (:obj:`str`): Optionally select the schema to drop the configs from
+
+        Example:
+            .. code-block:: python
+
+                from dask_sql import Context
+
+                c = Context()
+                c.set_config(
+                    {
+                        "dask.groupby.aggregate.split_out": 2,
+                        "dask.groupby.aggregate.split_every": 4,
+                    }
+                )
+                c.drop_config(
+                    [
+                        "dask.groupby.aggregate.split_out",
+                        "dask.groupby.aggregate.split_every",
+                    ]
+                )
+        """
+        schema_name = schema_name or self.schema_name
+        self.schema[schema_name].config.drop_config(config_strs)
+
     def ipython_magic(self, auto_include=False):  # pragma: no cover
         """
         Register a new ipython/jupyter magic function "sql"
diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py
index c69dee79c..6515fe350 100644
--- a/dask_sql/datacontainer.py
+++ b/dask_sql/datacontainer.py
@@ -230,3 +230,88 @@ def __init__(self, name: str):
         self.models: Dict[str, Tuple[Any, List[str]]] = {}
         self.functions: Dict[str, UDF] = {}
         self.function_lists: List[FunctionDescription] = []
+        self.config: ConfigContainer = ConfigContainer()
+
+
+class ConfigContainer:
+    """
+    Helper class that contains configuration options required for specific operations.
+    Configurations are stored in a dictionary whose string keys are delimited by `.`
+    for easier nested access of multiple configurations.
+
+    Example:
+        Dask groupby aggregate operations can be configured with the `split_out` option
+        to determine the number of output partitions or the `split_every` option to
+        determine the number of partitions used during the groupby tree reduction step.
+ """ + + def __init__(self): + self.config_dict = { + # Do not set defaults here unless needed + # This mantains the list of configuration options supported that can be set + # "dask.groupby.aggregate.split_out": 1, + # "dask.groupby.aggregate.split_every": None, + } + + def set_config(self, config_options: Union[Tuple[str, Any], Dict[str, Any]]): + """ + Accepts either a tuple of (config, val) or a dictionary containing multiple + {config1: val1, config2: val2} pairs and updates the schema config with these values + """ + if isinstance(config_options, tuple): + config_options = [config_options] + self.config_dict.update(config_options) + + def drop_config(self, config_strs: Union[str, List[str]]): + if isinstance(config_strs, str): + config_strs = [config_strs] + for config_key in config_strs: + self.config_dict.pop(config_key) + + def get_config_by_prefix(self, config_prefix: str): + """ + Returns all configuration options matching the prefix in `config_prefix` + + Example: + .. code-block:: python + + from dask_sql.datacontainer import ConfigContainer + + sql_config = ConfigContainer() + sql_config.set_config( + { + "dask.groupby.aggregate.split_out":1, + "dask.groupby.aggregate.split_every": 1, + "dask.sort.persist": True, + } + ) + + sql_config.get_config_by_prefix("dask.groupby") + # Returns { + # "dask.groupby.aggregate.split_out": 1, + # "dask.groupby.aggregate.split_every": 1 + # } + + sql_config.get_config_by_prefix("dask") + # Returns { + # "dask.groupby.aggregate.split_out": 1, + # "dask.groupby.aggregate.split_every": 1, + # "dask.sort.persist": True + # } + + sql_config.get_config_by_prefix("dask.sort") + # Returns {"dask.sort.persist": True} + + sql_config.get_config_by_prefix("missing.key") + sql_config.get_config_by_prefix(None) + # Both return {} + + """ + return ( + { + key: val + for key, val in self.config_dict.items() + if key.startswith(config_prefix) + } + if config_prefix + else {} + ) diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 88ecdd978..83c19fca0 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -231,6 +231,15 @@ def _do_aggregations( for col in group_columns: collected_aggregations[None].append((col, col, "first")) + groupby_agg_options = context.schema[ + context.schema_name + ].config.get_config_by_prefix("dask.groupby.aggregate") + # Update the config string to only include the actual param value + # i.e. dask.groupby.aggregate.split_out -> split_out + for config_key in list(groupby_agg_options.keys()): + groupby_agg_options[ + config_key.rpartition(".")[2] + ] = groupby_agg_options.pop(config_key) # Now we can go ahead and use these grouped aggregations # to perform the actual aggregation # It is very important to start with the non-filtered entry. @@ -240,13 +249,23 @@ def _do_aggregations( if key in collected_aggregations: aggregations = collected_aggregations.pop(key) df_result = self._perform_aggregation( - df, None, aggregations, additional_column_name, group_columns, + df, + None, + aggregations, + additional_column_name, + group_columns, + groupby_agg_options, ) # Now we can also the the rest for filter_column, aggregations in collected_aggregations.items(): agg_result = self._perform_aggregation( - df, filter_column, aggregations, additional_column_name, group_columns, + df, + filter_column, + aggregations, + additional_column_name, + group_columns, + groupby_agg_options, ) # ... 
diff --git a/dask_sql/physical/utils/groupby.py b/dask_sql/physical/utils/groupby.py
index 5a4e78246..33479bd20 100644
--- a/dask_sql/physical/utils/groupby.py
+++ b/dask_sql/physical/utils/groupby.py
@@ -23,6 +23,9 @@ def get_groupby_with_nulls_cols(
         is_null_column = ~(group_column.isnull())
         non_nan_group_column = group_column.fillna(0)

+        # split_out doesn't work if both columns have the same name
+        is_null_column.name = f"{is_null_column.name}_{new_temporary_column(df)}"
+
         group_columns_and_nulls += [is_null_column, non_nan_group_column]

     if not group_columns_and_nulls:
diff --git a/tests/integration/test_groupby.py b/tests/integration/test_groupby.py
index 91907249e..87efa84f9 100644
--- a/tests/integration/test_groupby.py
+++ b/tests/integration/test_groupby.py
@@ -1,5 +1,7 @@
 import numpy as np
 import pandas as pd
+import pytest
+from dask import dataframe as dd
 from pandas.testing import assert_frame_equal, assert_series_equal

@@ -345,3 +347,60 @@ def test_stats_aggregation(c, timeseries_df):
         check_dtype=False,
         check_names=False,
     )
+
+
+@pytest.mark.parametrize(
+    "input_table",
+    ["user_table_1", pytest.param("gpu_user_table_1", marks=pytest.mark.gpu)],
+)
+@pytest.mark.parametrize("split_out", [None, 2, 4])
+def test_groupby_split_out(c, input_table, split_out, request):
+    user_table = request.getfixturevalue(input_table)
+    c.set_config(("dask.groupby.aggregate.split_out", split_out))
+    df = c.sql(
+        f"""
+        SELECT
+            user_id, SUM(b) AS "S"
+        FROM {input_table}
+        GROUP BY user_id
+        """
+    )
+    expected_df = (
+        user_table.groupby(by="user_id").agg({"b": "sum"}).reset_index(drop=False)
+    )
+    expected_df = expected_df.rename(columns={"b": "S"})
+    expected_df = expected_df.sort_values("user_id")
+    assert df.npartitions == (split_out if split_out else 1)
+    dd.assert_eq(df.compute().sort_values("user_id"), expected_df, check_index=False)
+    c.drop_config("dask.groupby.aggregate.split_out")
+
+
+@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
+@pytest.mark.parametrize("split_every,expected_keys", [(2, 154), (3, 148), (4, 144)])
+def test_groupby_split_every(c, gpu, split_every, expected_keys):
+    xd = pytest.importorskip("cudf") if gpu else pd
+    input_ddf = dd.from_pandas(
+        xd.DataFrame({"user_id": [1, 2, 3, 4] * 16, "b": [5, 6, 7, 8] * 16}),
+        npartitions=16,
+    )  # Need an input with multiple partitions to demonstrate split_every
+    c.create_table("split_every_input", input_ddf)
+    c.set_config(("dask.groupby.aggregate.split_every", split_every))
+    df = c.sql(
+        """
+        SELECT
+            user_id, SUM(b) AS "S"
+        FROM split_every_input
+        GROUP BY user_id
+        """
+    )
+    expected_df = (
+        input_ddf.groupby(by="user_id")
+        .agg({"b": "sum"}, split_every=split_every)
+        .reset_index(drop=False)
+    )
+    expected_df = expected_df.rename(columns={"b": "S"})
+    expected_df = expected_df.sort_values("user_id")
+    assert len(df.dask.keys()) == expected_keys
+    dd.assert_eq(df.compute().sort_values("user_id"), expected_df, check_index=False)
+    c.drop_config("dask.groupby.aggregate.split_every")
+    c.drop_table("split_every_input")
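For reference, a minimal end-to-end sketch of the configuration flow that these tests exercise (the table name, data, and option value are illustrative, and assume a dask_sql build that includes this change):

import pandas as pd
import dask.dataframe as dd
from dask_sql import Context

c = Context()
ddf = dd.from_pandas(
    pd.DataFrame({"user_id": [1, 2, 3, 4] * 8, "b": [5, 6, 7, 8] * 8}),
    npartitions=8,
)
c.create_table("example_input", ddf)

# Ask the groupby aggregation for two output partitions
c.set_config(("dask.groupby.aggregate.split_out", 2))
result = c.sql('SELECT user_id, SUM(b) AS "S" FROM example_input GROUP BY user_id')
assert result.npartitions == 2

# Remove the option again so later queries use the default single output partition
c.drop_config("dask.groupby.aggregate.split_out")
c.drop_table("example_input")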