Skip to content

Commit cd749a3

Browse files
committed
Fix regression for regression
rapidsai#6614 broke BaseRandomForestModel for the "regressoin" op type. In this case, the shape we provided Dask was wrong, which eventually caused errors in `dask.array.concatenate` trying to convert cupy arrays to ndarrays. I'm not sure why CI didn't catch this. Perhaps older versions of dask weren't susceptible to this issue.
1 parent ac0e51c commit cd749a3

1 file changed

Lines changed: 12 additions & 10 deletions

File tree

  • python/cuml/cuml/dask/ensemble

python/cuml/cuml/dask/ensemble/base.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,14 @@ def _partial_inference(self, X, op_type, delayed, **kwargs):
199199
data = DistributedDataHandler.create(X, client=self.client)
200200
combined_data = list(map(lambda x: x[1], data.gpu_futures))
201201

202-
func = (
203-
_func_predict_partial
204-
if op_type == "regression"
205-
else _func_predict_proba_partial
206-
)
202+
if op_type == "classification":
203+
func = _func_predict_proba_partial
204+
shape = (X.shape[0], 1, self.num_classes)
205+
else:
206+
shape = (X.shape[0], 1)
207+
func = _func_predict_partial
208+
209+
meta = cp.zeros((0,) * len(shape), dtype=cp.float32)
207210

208211
partial_infs = list()
209212
for worker in self.active_workers:
@@ -217,14 +220,13 @@ def _partial_inference(self, X, op_type, delayed, **kwargs):
217220
pure=False,
218221
)
219222
)
220-
shape = (X.shape[0], 1, self.num_classes)
223+
224+
meta = cp.zeros((0,) * 3, dtype=cp.float32)
221225
objs = [
222-
dask.array.from_delayed(partial_inf, shape=shape, dtype=np.float32)
226+
dask.array.from_delayed(partial_inf, shape=shape, meta=meta)
223227
for partial_inf in partial_infs
224228
]
225-
result = dask.array.concatenate(
226-
objs, axis=1, allow_unknown_chunksizes=True
227-
)
229+
result = dask.array.concatenate(objs, axis=1)
228230
return result
229231

230232
def _predict_using_fil(self, X, delayed, **kwargs):

0 commit comments

Comments
 (0)