diff --git a/cpp/src/dbscan/mergelabels/tree_reduction.cuh b/cpp/src/dbscan/mergelabels/tree_reduction.cuh index dd2630b01f..061657aace 100644 --- a/cpp/src/dbscan/mergelabels/tree_reduction.cuh +++ b/cpp/src/dbscan/mergelabels/tree_reduction.cuh @@ -39,12 +39,13 @@ void tree_reduction(const raft::handle_t& handle, const auto& comm = handle.get_comms(); int my_rank = comm.get_rank(); int n_rank = comm.get_size(); - raft::comms::request_t request; int s = 1; while (s < n_rank) { CUML_LOG_DEBUG("Tree reduction, s=", s); + raft::comms::request_t request; + // Find out whether the node is a receiver / sender / passive bool receiver = my_rank % (2 * s) == 0 && my_rank + s < n_rank; bool sender = my_rank % (2 * s) == s; @@ -57,7 +58,7 @@ void tree_reduction(const raft::handle_t& handle, comm.isend(labels, N, my_rank - s, 0, &request); } - comm.waitall(1, &request); + if (receiver || sender) { comm.waitall(1, &request); } if (receiver) { CUML_LOG_DEBUG("--> Merge labels"); diff --git a/python/cuml/cuml/dask/cluster/dbscan.py b/python/cuml/cuml/dask/cluster/dbscan.py index 2fd9192ee2..09cbe42b67 100644 --- a/python/cuml/cuml/dask/cluster/dbscan.py +++ b/python/cuml/cuml/dask/cluster/dbscan.py @@ -104,8 +104,12 @@ def fit(self, X, out_dtype="int32"): data = self.client.scatter(X, broadcast=True) - comms = Comms(comms_p2p=True) - comms.init() + # Get the workers that actually hold the scattered data + who_has = self.client.who_has(data) + workers = list(who_has[data.key]) + + comms = Comms(comms_p2p=True, client=self.client) + comms.init(workers=workers) # Get worker info to map workers to their RAFT ranks worker_info = comms.worker_info(comms.worker_addresses)