Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@

import cudf

# Guidance appended to NotImplementedError messages for features that are
# only available through the legacy (non-query-planning) dask-cudf API.
# NOTE: the config must be set *before* dask is imported to take effect.
_LEGACY_WORKAROUND = (
    "To disable query planning, set the global "
    "'dataframe.query-planning' config to `False` "
    "before dask is imported. This can also be done "
    "by setting an environment variable: "
    "`DASK_DATAFRAME__QUERY_PLANNING=False` "
)


##
## Custom collection classes
##
Expand Down Expand Up @@ -62,6 +71,7 @@ def groupby(
sort=None,
observed=None,
dropna=None,
as_index=True,
**kwargs,
):
from dask_cudf.expr._groupby import GroupBy
Expand All @@ -71,6 +81,13 @@ def groupby(
f"`by` must be a column name or list of columns, got {by}."
)

if as_index is not True:
raise NotImplementedError(
f"`as_index` is not supported by dask-expr. Please disable "
"query planning, or reset the index after aggregating.\n"
f"{_LEGACY_WORKAROUND}"
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This is a somewhat confusing error message. The only way to get past it with query planning enabled is to say `as_index=True`, but the error message seems to say "`as_index=True` is not handled by dask-expr."

Do you mean:

dask-expr only supports `as_index=True`. For `as_index=False`, either disable query planning or reset the index with `reset_index` after aggregating.

WDYT?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds like dask-expr doesn't actually have support for the as_index keyword arg in general and always follows the behavior of as_index=False, so perhaps we should consider:

  • checking if the kwarg is provided at all, emitting a FutureWarning if so
  • if the kwarg isn't as_index=True, raise the error description suggested above

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The upstream dask.dataframe API has always raised an error when as_index is used, but dask-cudf has used a distinct groupby API until now.

I agree that the real problem is that as_index is not supported at all by dask-expr. Therefore @charlesbluca's suggestion is probably the most "correct". With that said, I'm feeling a bit hesitant to add more noise for something that technically "works fine" :/


return GroupBy(
self,
by,
Expand Down
54 changes: 54 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,55 @@
from dask_expr._groupby import (
GroupBy as DXGroupBy,
SeriesGroupBy as DXSeriesGroupBy,
SingleAggregation,
)
from dask_expr._util import is_scalar

from dask.dataframe.groupby import Aggregation

##
## Custom groupby classes
##


class Collect(SingleAggregation):
    """Custom "collect" (list) groupby aggregation for dask-expr.

    The chunk phase gathers each group's values into list columns within
    every partition; the aggregate phase collects those per-partition
    lists and flattens them with ``.list.concat()`` so every group ends
    up with one flat list.
    """

    @staticmethod
    def groupby_chunk(*args, **kwargs):
        # Per-partition step: collect group values into lists.
        return args[0].agg("collect")

    @staticmethod
    def groupby_aggregate(*args, **kwargs):
        # Tree-reduction step: collect the partial lists, then flatten.
        combined = args[0].agg("collect")
        if combined.ndim <= 1:
            # Series result: a single list column to flatten.
            return combined.list.concat()
        # DataFrame result: flatten every list column in place.
        for column in combined.columns:
            combined[column] = combined[column].list.concat()
        return combined


collect_aggregation = Aggregation(
name="collect",
chunk=Collect.groupby_chunk,
agg=Collect.groupby_aggregate,
)


def _translate_arg(arg):
# Helper function to translate args so that
# they can be processed correctly by upstream
# dask & dask-expr. Right now, the only necessary
# translation is "collect" aggregations.
if isinstance(arg, dict):
return {k: _translate_arg(v) for k, v in arg.items()}
elif isinstance(arg, list):
return [_translate_arg(x) for x in arg]
elif arg in ("collect", "list", list):
return collect_aggregation
else:
return arg


# TODO: These classes are mostly a work-around for missing
# `observed=False` support.
# See: https://github.com/rapidsai/cudf/issues/15173
Expand Down Expand Up @@ -41,8 +83,20 @@ def __getitem__(self, key):
)
return g

    def collect(self, **kwargs):
        """Collect the values in each group into a list column."""
        return self._single_agg(Collect, **kwargs)

    def aggregate(self, arg, **kwargs):
        """Aggregate, first translating "collect"/"list" requests in
        ``arg`` into the custom collect aggregation (see
        ``_translate_arg``)."""
        return super().aggregate(_translate_arg(arg), **kwargs)


class SeriesGroupBy(DXSeriesGroupBy):
    """Series groupby wrapper adding cudf-specific behavior.

    Defaults ``observed`` to ``True`` (work-around for missing
    ``observed=False`` support — see the note above), and adds
    list-collection support on top of the dask-expr base class.
    """

    def __init__(self, *args, observed=None, **kwargs):
        # Force observed=True unless the caller set it explicitly.
        if observed is None:
            observed = True
        super().__init__(*args, observed=observed, **kwargs)

    def collect(self, **kwargs):
        """Collect the values in each group into a list."""
        return self._single_agg(Collect, **kwargs)

    def aggregate(self, arg, **kwargs):
        """Aggregate, translating "collect"/"list" specs in ``arg`` first."""
        return super().aggregate(_translate_arg(arg), **kwargs)
26 changes: 12 additions & 14 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,6 @@
from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized
from dask_cudf.tests.utils import QUERY_PLANNING_ON, xfail_dask_expr

# XFAIL "collect" tests for now
agg_params = [agg for agg in OPTIMIZED_AGGS if agg != "collect"]
if QUERY_PLANNING_ON:
agg_params.append(
# TODO: "collect" not supported with dask-expr yet
pytest.param("collect", marks=pytest.mark.xfail)
)
else:
agg_params.append("collect")


def assert_cudf_groupby_layers(ddf):
for prefix in ("cudf-aggregate-chunk", "cudf-aggregate-agg"):
Expand Down Expand Up @@ -57,7 +47,7 @@ def pdf(request):
return pdf


@pytest.mark.parametrize("aggregation", agg_params)
@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS)
@pytest.mark.parametrize("series", [False, True])
def test_groupby_basic(series, aggregation, pdf):
gdf = cudf.DataFrame.from_pandas(pdf)
Expand Down Expand Up @@ -110,7 +100,7 @@ def test_groupby_cumulative(aggregation, pdf, series):
dd.assert_eq(a, b)


@pytest.mark.parametrize("aggregation", agg_params)
@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS)
@pytest.mark.parametrize(
"func",
[
Expand Down Expand Up @@ -579,8 +569,16 @@ def test_groupby_categorical_key():
dd.assert_eq(expect, got)


@xfail_dask_expr("as_index not supported in dask-expr")
@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize(
"as_index",
[
True,
pytest.param(
False,
marks=xfail_dask_expr("as_index not supported in dask-expr"),
),
],
)
@pytest.mark.parametrize("split_out", ["use_dask_default", 1, 2])
@pytest.mark.parametrize("split_every", [False, 4])
@pytest.mark.parametrize("npartitions", [1, 10])
Expand Down