Skip to content

Commit eda9bcc

Browse files
authored
__dask_distributed_pack__(): client argument (#4248)
1 parent 04a6b78 commit eda9bcc

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

distributed/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2581,7 +2581,7 @@ def _graph_to_futures(
25812581
if not isinstance(dsk, HighLevelGraph):
25822582
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
25832583

2584-
dsk = highlevelgraph_pack(dsk, keyset, self, self.futures)
2584+
dsk = highlevelgraph_pack(dsk, self, keyset)
25852585

25862586
if isinstance(retries, Number) and retries > 0:
25872587
retries = {k: retries for k in dsk}

distributed/protocol/highlevelgraph.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ def _materialized_layer_pack(
2727
layer: Layer,
2828
all_keys,
2929
known_key_dependencies,
30+
client,
3031
client_keys,
31-
allowed_client,
32-
allowed_futures,
3332
):
3433
from ..client import Future
3534

@@ -47,11 +46,11 @@ def _materialized_layer_pack(
4746
dsk = {k: unpack_remotedata(v, byte_keys=True) for k, v in layer.items()}
4847
unpacked_futures = set.union(*[v[1] for v in dsk.values()]) if dsk else set()
4948
for future in unpacked_futures:
50-
if future.client is not allowed_client:
49+
if future.client is not client:
5150
raise ValueError(
5251
"Inputs contain futures that were created by another client."
5352
)
54-
if tokey(future.key) not in allowed_futures:
53+
if tokey(future.key) not in client.futures:
5554
raise CancelledError(tokey(future.key))
5655
unpacked_futures_deps = {}
5756
for k, v in dsk.items():
@@ -76,15 +75,13 @@ def _materialized_layer_pack(
7675
return {"dsk": dsk, "dependencies": dependencies}
7776

7877

79-
def highlevelgraph_pack(
80-
hlg: HighLevelGraph, client_keys, allowed_client, allowed_futures
81-
):
78+
def highlevelgraph_pack(hlg: HighLevelGraph, client, client_keys):
8279
layers = []
8380

8481
# Dump each layer (in topological order)
8582
for layer in (hlg.layers[name] for name in hlg._toposort_layers()):
8683
if not layer.is_materialized():
87-
state = layer.__dask_distributed_pack__()
84+
state = layer.__dask_distributed_pack__(client)
8885
if state is not None:
8986
layers.append(
9087
{
@@ -104,9 +101,8 @@ def highlevelgraph_pack(
104101
layer,
105102
hlg.get_all_external_keys(),
106103
hlg.key_dependencies,
104+
client,
107105
client_keys,
108-
allowed_client,
109-
allowed_futures,
110106
),
111107
}
112108
)

0 commit comments

Comments
 (0)