From 469f170bbcbc40367912f719a79ff7fd1acb0f36 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 24 Apr 2024 11:39:57 -0700 Subject: [PATCH 1/7] add collect support --- .../dask_cudf/dask_cudf/expr/_collection.py | 32 +++++++++++++++++++ python/dask_cudf/dask_cudf/expr/_groupby.py | 24 ++++++++++++++ .../dask_cudf/dask_cudf/tests/test_groupby.py | 26 +++++++-------- 3 files changed, 68 insertions(+), 14 deletions(-) diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index 516e35a4335..1a8b20ae326 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -17,6 +17,15 @@ import cudf +_LEGACY_WORKAROUND = ( + "To disable query planning, set the global " + "'dataframe.query-planning' config to `False` " + "before dask is imported. This can also be done " + "by setting an environment variable: " + "`DASK_DATAFRAME__QUERY_PLANNING=False` " +) + + ## ## Custom collection classes ## @@ -62,6 +71,7 @@ def groupby( sort=None, observed=None, dropna=None, + as_index=True, **kwargs, ): from dask_cudf.expr._groupby import GroupBy @@ -71,6 +81,13 @@ def groupby( f"`by` must be a column name or list of columns, got {by}." ) + if as_index is not True: + raise NotImplementedError( + f"`as_index` is not supported by dask-expr. Please disable " + "query planning, or reset the index after aggregating.\n" + f"{_LEGACY_WORKAROUND}" + ) + return GroupBy( self, by, @@ -151,3 +168,18 @@ def get_collection_type_csr_matrix(_): # Older version of dask-expr. # Implicit conversion to array wont work. pass + + +## +## Helper functions +## + + +def _legacy_instructions(): + return ( + "To disable query-planning, set the global " + "'dataframe.query-planning' config to `False` " + "before importing dask. This can also be done " + "by setting an environment variable: " + "`DASK_DATAFRAME__QUERY_PLANNING=False` " + ) diff --git a/python/dask_cudf/dask_cudf/expr/_groupby.py b/python/dask_cudf/dask_cudf/expr/_groupby.py index 7f275151f75..48528263a33 100644 --- a/python/dask_cudf/dask_cudf/expr/_groupby.py +++ b/python/dask_cudf/dask_cudf/expr/_groupby.py @@ -3,6 +3,7 @@ from dask_expr._groupby import ( GroupBy as DXGroupBy, SeriesGroupBy as DXSeriesGroupBy, + SingleAggregation, ) from dask_expr._util import is_scalar @@ -10,6 +11,23 @@ ## Custom groupby classes ## + +class Collect(SingleAggregation): + @staticmethod + def groupby_chunk(*args, **kwargs): + return args[0].agg("collect") + + @staticmethod + def groupby_aggregate(*args, **kwargs): + gb = args[0].agg("collect") + if gb.ndim > 1: + for col in gb.columns: + gb[col] = gb[col].list.concat() + return gb + else: + return gb.list.concat() + + # TODO: These classes are mostly a work-around for missing # `observed=False` support. # See: https://github.com/rapidsai/cudf/issues/15173 @@ -41,8 +59,14 @@ def __getitem__(self, key): ) return g + def collect(self, **kwargs): + return self._single_agg(Collect, **kwargs) + class SeriesGroupBy(DXSeriesGroupBy): def __init__(self, *args, observed=None, **kwargs): observed = observed if observed is not None else True super().__init__(*args, observed=observed, **kwargs) + + def collect(self, **kwargs): + return self._single_agg(Collect, **kwargs) diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 3bb3e3b0bb8..bd455d35bee 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -14,16 +14,6 @@ from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized from dask_cudf.tests.utils import QUERY_PLANNING_ON, xfail_dask_expr -# XFAIL "collect" tests for now -agg_params = [agg for agg in OPTIMIZED_AGGS if agg != "collect"] -if QUERY_PLANNING_ON: - agg_params.append( - # TODO: "collect" not supported with dask-expr yet - pytest.param("collect", marks=pytest.mark.xfail) - ) -else: - agg_params.append("collect") - def assert_cudf_groupby_layers(ddf): for prefix in ("cudf-aggregate-chunk", "cudf-aggregate-agg"): @@ -57,7 +47,7 @@ def pdf(request): return pdf -@pytest.mark.parametrize("aggregation", agg_params) +@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS) @pytest.mark.parametrize("series", [False, True]) def test_groupby_basic(series, aggregation, pdf): gdf = cudf.DataFrame.from_pandas(pdf) @@ -110,7 +100,7 @@ def test_groupby_cumulative(aggregation, pdf, series): dd.assert_eq(a, b) -@pytest.mark.parametrize("aggregation", agg_params) +@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS) @pytest.mark.parametrize( "func", [ @@ -579,8 +569,16 @@ def test_groupby_categorical_key(): dd.assert_eq(expect, got) -@xfail_dask_expr("as_index not supported in dask-expr") -@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize( + "as_index", + [ + True, + pytest.param( + False, + marks=xfail_dask_expr("as_index not supported in dask-expr"), + ), + ], +) @pytest.mark.parametrize("split_out", ["use_dask_default", 1, 2]) @pytest.mark.parametrize("split_every", [False, 4]) @pytest.mark.parametrize("npartitions", [1, 10]) From 4fc20257b017fe647f52992f6aabd97ee78fc1a7 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Fri, 26 Apr 2024 07:22:45 -0700 Subject: [PATCH 2/7] remove dependency on upstream PR --- python/dask_cudf/dask_cudf/expr/_groupby.py | 30 +++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/python/dask_cudf/dask_cudf/expr/_groupby.py b/python/dask_cudf/dask_cudf/expr/_groupby.py index 48528263a33..b960dd5eb62 100644 --- a/python/dask_cudf/dask_cudf/expr/_groupby.py +++ b/python/dask_cudf/dask_cudf/expr/_groupby.py @@ -7,6 +7,8 @@ ) from dask_expr._util import is_scalar +from dask.dataframe.groupby import Aggregation + ## ## Custom groupby classes ## @@ -28,6 +30,28 @@ def groupby_aggregate(*args, **kwargs): return gb.list.concat() +collect_aggregation = Aggregation( + name="collect", + chunk=Collect.groupby_chunk, + agg=Collect.groupby_aggregate, +) + + +def _translate_arg(arg): + # Helper function to translate args so that + # they cab be processed correctly by upstream + # dask & dask-expr. Right now, the only necessary + # translation is list ("collect") aggregations. + if isinstance(arg, dict): + return {k: _translate_arg(v) for k, v in arg.items()} + elif isinstance(arg, list): + return [_translate_arg(x) for x in arg] + elif arg in ("collect", "list", list): + return collect_aggregation + else: + return arg + + # TODO: These classes are mostly a work-around for missing # `observed=False` support. # See: https://github.com/rapidsai/cudf/issues/15173 @@ -62,6 +86,9 @@ def __getitem__(self, key): def collect(self, **kwargs): return self._single_agg(Collect, **kwargs) + def aggregate(self, arg, **kwargs): + return super().aggregate(_translate_arg(arg), **kwargs) + class SeriesGroupBy(DXSeriesGroupBy): def __init__(self, *args, observed=None, **kwargs): @@ -70,3 +97,6 @@ def __init__(self, *args, observed=None, **kwargs): def collect(self, **kwargs): return self._single_agg(Collect, **kwargs) + + def aggregate(self, arg, **kwargs): + return super().aggregate(_translate_arg(arg), **kwargs) From 61015caeb8991a8d4b6bbe69a67ebfff723c6cf1 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Fri, 26 Apr 2024 07:32:43 -0700 Subject: [PATCH 3/7] remove stale code --- python/dask_cudf/dask_cudf/expr/_collection.py | 15 --------------- python/dask_cudf/dask_cudf/expr/_groupby.py | 4 ++-- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index 1a8b20ae326..2b3285531a4 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -168,18 +168,3 @@ def get_collection_type_csr_matrix(_): # Older version of dask-expr. # Implicit conversion to array wont work. pass - - -## -## Helper functions -## - - -def _legacy_instructions(): - return ( - "To disable query-planning, set the global " - "'dataframe.query-planning' config to `False` " - "before importing dask. This can also be done " - "by setting an environment variable: " - "`DASK_DATAFRAME__QUERY_PLANNING=False` " - ) diff --git a/python/dask_cudf/dask_cudf/expr/_groupby.py b/python/dask_cudf/dask_cudf/expr/_groupby.py index b960dd5eb62..618e8aeb6fe 100644 --- a/python/dask_cudf/dask_cudf/expr/_groupby.py +++ b/python/dask_cudf/dask_cudf/expr/_groupby.py @@ -39,9 +39,9 @@ def groupby_aggregate(*args, **kwargs): def _translate_arg(arg): # Helper function to translate args so that - # they cab be processed correctly by upstream + # they can be processed correctly by upstream # dask & dask-expr. Right now, the only necessary - # translation is list ("collect") aggregations. + # translation is "collect" aggregations. if isinstance(arg, dict): return {k: _translate_arg(v) for k, v in arg.items()} elif isinstance(arg, list): From 5c700b2452f65fbf4810af37ce4d3bbb34130e75 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 29 Apr 2024 12:02:27 -0700 Subject: [PATCH 4/7] add warning - need to revise tests --- python/dask_cudf/dask_cudf/expr/_collection.py | 13 ++++++++++--- python/dask_cudf/dask_cudf/expr/_groupby.py | 8 ++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index 2b3285531a4..c8cb479fdcd 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -1,5 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +import warnings from functools import cached_property from dask_expr import ( @@ -71,7 +72,6 @@ def groupby( sort=None, observed=None, dropna=None, - as_index=True, **kwargs, ): from dask_cudf.expr._groupby import GroupBy @@ -81,9 +81,16 @@ def groupby( f"`by` must be a column name or list of columns, got {by}." ) - if as_index is not True: + if "as_index" in kwargs: + warnings.warn( + "The `as_index` argument is no longer supported in " + "dask-cudf when query-planning is enabled.", + FutureWarning, + ) + + if kwargs.pop("as_index", True) is not True: raise NotImplementedError( - f"`as_index` is not supported by dask-expr. Please disable " + f"`as_index=False` is not supported. Please disable " "query planning, or reset the index after aggregating.\n" f"{_LEGACY_WORKAROUND}" ) diff --git a/python/dask_cudf/dask_cudf/expr/_groupby.py b/python/dask_cudf/dask_cudf/expr/_groupby.py index 618e8aeb6fe..116893891e3 100644 --- a/python/dask_cudf/dask_cudf/expr/_groupby.py +++ b/python/dask_cudf/dask_cudf/expr/_groupby.py @@ -16,12 +16,12 @@ class Collect(SingleAggregation): @staticmethod - def groupby_chunk(*args, **kwargs): - return args[0].agg("collect") + def groupby_chunk(arg): + return arg.agg("collect") @staticmethod - def groupby_aggregate(*args, **kwargs): - gb = args[0].agg("collect") + def groupby_aggregate(arg): + gb = arg.agg("collect") if gb.ndim > 1: for col in gb.columns: gb[col] = gb[col].list.concat() From 2170e18c775a2901dd8748b655f4fb917c4df28e Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 1 May 2024 09:58:13 -0700 Subject: [PATCH 5/7] clean up the message a bit --- .../dask_cudf/dask_cudf/expr/_collection.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index 56f3152c653..c60b4e6a1f7 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -19,10 +19,10 @@ import cudf _LEGACY_WORKAROUND = ( - "To disable query planning, set the global " - "'dataframe.query-planning' config to `False` " - "before dask is imported. This can also be done " - "by setting an environment variable: " + "To enable the 'legacy' dask-cudf API, set the " + "global 'dataframe.query-planning' config to " + "`False` before dask is imported. This can also " + "be done by setting an environment variable: " "`DASK_DATAFRAME__QUERY_PLANNING=False` " ) @@ -98,18 +98,19 @@ def groupby( ) if "as_index" in kwargs: - warnings.warn( - "The `as_index` argument is no longer supported in " - "dask-cudf when query-planning is enabled.", - FutureWarning, + msg = ( + "The `as_index` argument is now deprecated. All groupby " + "results will be consistent with `as_index=True`.", ) - if kwargs.pop("as_index", True) is not True: - raise NotImplementedError( - f"`as_index=False` is not supported. Please disable " - "query planning, or reset the index after aggregating.\n" - f"{_LEGACY_WORKAROUND}" - ) + if kwargs.pop("as_index") is not True: + raise NotImplementedError( + f"{msg} Please reset the index after aggregating, or " + "use the legacy API if `as_index=False` is required.\n" + f"{_LEGACY_WORKAROUND}" + ) + else: + warnings.warn(msg, FutureWarning) return GroupBy( self, From 42b35c9cd4a8725449f8ee7ec89cf16425a0f1a2 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Wed, 1 May 2024 12:44:55 -0500 Subject: [PATCH 6/7] Update python/dask_cudf/dask_cudf/expr/_collection.py --- python/dask_cudf/dask_cudf/expr/_collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index c60b4e6a1f7..d50dfb24256 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -100,7 +100,7 @@ def groupby( if "as_index" in kwargs: msg = ( "The `as_index` argument is now deprecated. All groupby " - "results will be consistent with `as_index=True`.", + "results will be consistent with `as_index=True`." ) if kwargs.pop("as_index") is not True: From ca303be6ec55047e3d7ef400906e4d408bb21be5 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 1 May 2024 13:08:27 -0700 Subject: [PATCH 7/7] avoid breaking 15634 --- python/dask_cudf/dask_cudf/tests/test_groupby.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 703253fff7a..67fa045d3d0 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -601,10 +601,19 @@ def test_groupby_agg_params(npartitions, split_every, split_out, as_index): if split_out == "use_dask_default": split_kwargs.pop("split_out") + # Avoid using as_index when query-planning is enabled + if QUERY_PLANNING_ON: + with pytest.warns(FutureWarning, match="argument is now deprecated"): + # Should warn when `as_index` is used + ddf.groupby(["name", "a"], sort=False, as_index=as_index) + maybe_as_index = {"as_index": as_index} if as_index is False else {} + else: + maybe_as_index = {"as_index": as_index} + # Check `sort=True` behavior if split_out == 1: gf = ( - ddf.groupby(["name", "a"], sort=True, as_index=as_index) + ddf.groupby(["name", "a"], sort=True, **maybe_as_index) .aggregate( agg_dict, **split_kwargs, @@ -626,7 +635,7 @@ def test_groupby_agg_params(npartitions, split_every, split_out, as_index): ) # Full check (`sort=False`) - gr = ddf.groupby(["name", "a"], sort=False, as_index=as_index).aggregate( + gr = ddf.groupby(["name", "a"], sort=False, **maybe_as_index).aggregate( agg_dict, **split_kwargs, )