@@ -513,14 +513,14 @@ class cluster_loader {
513513 if (needs_copy_) {
514514 // For prefetching to overlap with other gpu work
515515 // we need to schedule copies on the provided copy stream stream_
516- auto stream = raft::resource::get_cuda_stream (res);
517- raft::resource::set_cuda_stream (res , stream_);
516+ auto copy_res = raft::resources (res);
517+ raft::resource::set_cuda_stream (copy_res , stream_);
518518
519519 // htod
520520 auto h_cluster_ids =
521521 raft::make_pinned_vector_view<LabelT, int64_t >(cluster_ids_buf_.data_handle (), size);
522522
523- raft::copy (res , h_cluster_ids, cluster_ids);
523+ raft::copy (copy_res , h_cluster_ids, cluster_ids);
524524 raft::resource::sync_stream (res, stream_);
525525
526526 auto pinned_cluster = raft::make_pinned_matrix_view<T, int64_t >(
@@ -534,11 +534,8 @@ class cluster_loader {
534534 sizeof (T) * dim_);
535535 }
536536
537- raft::copy (res , cluster_vectors, raft::make_const_mdspan (pinned_cluster));
537+ raft::copy (copy_res , cluster_vectors, raft::make_const_mdspan (pinned_cluster));
538538 raft::resource::sync_stream (res, stream_);
539-
540- // reset stream back to previous value
541- raft::resource::set_cuda_stream (res, stream);
542539 } else {
543540 // dtod
544541 auto dataset_view =
0 commit comments