Skip to content

Commit 79109d4

Browse files
authored
dask-polars: use splat everywhere. (#18243)
In the dask-graph, use the splat operator instead of list arguments everywhere. The splat operator is used in most places already, only a few functions are missing. This makes it easier to transform the dask-graph after its creation, which will be useful for prototyping. Authors: - Mads R. B. Kristensen (https://github.com/madsbk) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: #18243
1 parent 3e080dd commit 79109d4

File tree

5 files changed

+11
-9
lines changed

5 files changed

+11
-9
lines changed

python/cudf_polars/cudf_polars/experimental/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from cudf_polars.dsl.ir import Union
1010

1111
if TYPE_CHECKING:
12-
from collections.abc import Iterator, Sequence
12+
from collections.abc import Iterator
1313

1414
from cudf_polars.containers import DataFrame
1515
from cudf_polars.dsl.expr import NamedExpr
@@ -44,6 +44,6 @@ def get_key_name(node: Node) -> str:
4444
return f"{type(node).__name__.lower()}-{hash(node)}"
4545

4646

47-
def _concat(dfs: Sequence[DataFrame]) -> DataFrame:
47+
def _concat(*dfs: DataFrame) -> DataFrame:
4848
# Concatenate a sequence of DataFrames vertically
4949
return Union.do_evaluate(None, *dfs)

python/cudf_polars/cudf_polars/experimental/groupby.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,8 @@ def _(
251251
return new_node, partition_info
252252

253253

254-
def _tree_node(do_evaluate, batch, *args):
255-
return do_evaluate(*args, _concat(batch))
254+
def _tree_node(do_evaluate, nbatch, *args):
255+
return do_evaluate(*args[nbatch:], _concat(*args[:nbatch]))
256256

257257

258258
@generate_ir_tasks.register(GroupBy)
@@ -289,7 +289,8 @@ def _(
289289
graph[(name, j, i)] = (
290290
_tree_node,
291291
ir.do_evaluate,
292-
batch,
292+
len(batch),
293+
*batch,
293294
*ir._non_child_args,
294295
)
295296
new_keys.append((name, j, i))
@@ -298,7 +299,8 @@ def _(
298299
graph[(name, 0)] = (
299300
_tree_node,
300301
ir.do_evaluate,
301-
keys,
302+
len(keys),
303+
*keys,
302304
*ir._non_child_args,
303305
)
304306
return graph

python/cudf_polars/cudf_polars/experimental/join.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,6 @@ def _(
309309
if len(_concat_list) == 1:
310310
graph[(out_name, part_out)] = graph.pop(_concat_list[0])
311311
else:
312-
graph[(out_name, part_out)] = (_concat, _concat_list)
312+
graph[(out_name, part_out)] = (_concat, *_concat_list)
313313

314314
return graph

python/cudf_polars/cudf_polars/experimental/parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def task_graph(
150150
key_name = get_key_name(ir)
151151
partition_count = partition_info[ir].count
152152
if partition_count > 1:
153-
graph[key_name] = (_concat, list(partition_info[ir].keys(ir)))
153+
graph[key_name] = (_concat, *partition_info[ir].keys(ir))
154154
return graph, key_name
155155
else:
156156
return graph, (key_name, 0)

python/cudf_polars/cudf_polars/experimental/shuffle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _simple_shuffle_graph(
159159
(split_name, part_in),
160160
part_out,
161161
)
162-
graph[(name_out, part_out)] = (_concat, _concat_list)
162+
graph[(name_out, part_out)] = (_concat, *_concat_list)
163163
return graph
164164

165165

0 commit comments

Comments
 (0)