Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions cpp/src/cluster/detail/agglomerative.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@

#pragma once

#include <raft/core/copy.cuh>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/init.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -108,9 +112,15 @@ void build_dendrogram_host(raft::resources const& handle,
std::vector<value_idx> mst_dst_h(n_edges);
std::vector<value_t> mst_weights_h(n_edges);

raft::update_host(mst_src_h.data(), rows, n_edges, stream);
raft::update_host(mst_dst_h.data(), cols, n_edges, stream);
raft::update_host(mst_weights_h.data(), data, n_edges, stream);
raft::copy(handle,
raft::make_host_vector_view(mst_src_h.data(), n_edges),
raft::make_device_vector_view(rows, n_edges));
raft::copy(handle,
raft::make_host_vector_view(mst_dst_h.data(), n_edges),
raft::make_device_vector_view(cols, n_edges));
raft::copy(handle,
raft::make_host_vector_view(mst_weights_h.data(), n_edges),
raft::make_device_vector_view(data, n_edges));

raft::resource::sync_stream(handle, stream);

Expand Down Expand Up @@ -138,9 +148,15 @@ void build_dendrogram_host(raft::resources const& handle,
U.perform_union(aa, bb);
}

raft::update_device(children, children_h.data(), n_edges * 2, stream);
raft::update_device(out_size, out_size_h.data(), n_edges, stream);
raft::update_device(out_delta, out_delta_h.data(), n_edges, stream);
raft::copy(handle,
raft::make_device_vector_view(children, n_edges * 2),
raft::make_host_vector_view(children_h.data(), n_edges * 2));
raft::copy(handle,
raft::make_device_vector_view(out_size, n_edges),
raft::make_host_vector_view(out_size_h.data(), n_edges));
raft::copy(handle,
raft::make_device_vector_view(out_delta, n_edges),
raft::make_host_vector_view(out_delta_h.data(), n_edges));
}

template <typename value_idx>
Expand Down Expand Up @@ -236,7 +252,8 @@ void extract_flattened_clusters(raft::resources const& handle,

// Handle special case where n_clusters == 1
if (n_clusters == 1) {
thrust::fill(thrust_policy, labels, labels + n_leaves, 0);
raft::matrix::fill(
handle, raft::make_device_vector_view<value_idx>(labels, n_leaves), value_idx(0));
} else {
/**
* Compute levels for each node
Expand Down
7 changes: 5 additions & 2 deletions cpp/src/cluster/detail/connectivities.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -9,6 +9,9 @@
#include "./kmeans_common.cuh"
#include <cuvs/cluster/agglomerative.hpp>
#include <cuvs/distance/distance.hpp>
#include <raft/core/copy.cuh>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
Expand Down Expand Up @@ -144,7 +147,7 @@ void pairwise_distances(const raft::resources& handle,
raft::make_device_vector_view<value_idx, value_idx>(indptr, m),
[=] __device__(value_idx idx) { return idx * m; });

raft::update_device(indptr + m, &nnz, 1, stream);
raft::copy(handle, raft::make_device_scalar_view(indptr + m), raft::make_host_scalar_view(&nnz));

// TODO: It would ultimately be nice if the MST could accept
// dense inputs directly so we don't need to double the memory
Expand Down
145 changes: 75 additions & 70 deletions cpp/src/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
#include <cuvs/cluster/kmeans.hpp>
#include <cuvs/distance/distance.hpp>

#include <raft/core/copy.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
Expand All @@ -20,12 +23,15 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/map_then_reduce.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/linalg/reduce_cols_by_key.cuh>
#include <raft/linalg/reduce_rows_by_key.cuh>
#include <raft/matrix/gather.cuh>
#include <raft/matrix/init.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
Expand Down Expand Up @@ -133,8 +139,7 @@ void kmeansPlusPlus(raft::resources const& handle,

if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
}

raft::random::RngState rng(params.rng_state.seed, params.rng_state.type);
Expand All @@ -147,8 +152,9 @@ void kmeansPlusPlus(raft::resources const& handle,
int n_clusters_picked = 1;

// store the chosen centroid in the buffer
raft::copy(
centroidsRawData.data_handle(), initialCentroid.data_handle(), initialCentroid.size(), stream);
raft::copy(handle,
raft::make_device_vector_view(centroidsRawData.data_handle(), initialCentroid.size()),
raft::make_device_vector_view(initialCentroid.data_handle(), initialCentroid.size()));

// C = initial set of centroids
auto centroids = raft::make_device_matrix_view<DataT, IndexT>(
Expand Down Expand Up @@ -198,12 +204,13 @@ void kmeansPlusPlus(raft::resources const& handle,

// Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using
// centroid candidate-i
raft::linalg::reduce<true, true>(costPerCandidate.data_handle(),
minDistBuf.data_handle(),
minDistBuf.extent(1),
minDistBuf.extent(0),
static_cast<DataT>(0),
stream);
raft::linalg::reduce<raft::Apply::ALONG_ROWS>(
handle,
raft::make_device_matrix_view<const DataT, IndexT, raft::row_major>(
minDistBuf.data_handle(), minDistBuf.extent(0), minDistBuf.extent(1)),
raft::make_device_vector_view<DataT, IndexT>(costPerCandidate.data_handle(),
minDistBuf.extent(0)),
static_cast<DataT>(0));

// Greedy Choice - Choose the candidate that has minimum cluster cost
// ArgMin operation below identifies the index of minimum cost in costPerCandidate
Expand All @@ -229,21 +236,24 @@ void kmeansPlusPlus(raft::resources const& handle,
stream);

int bestCandidateIdx = -1;
raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream);
raft::copy(handle,
raft::make_host_scalar_view(&bestCandidateIdx),
raft::make_device_scalar_view(&minClusterIndexAndDistance.data()->key));
raft::resource::sync_stream(handle);
/// <<< End of Step-3 >>>

/// <<< Step-4 >>>: C = C U {x}
// Update minimum cluster distance corresponding to the chosen centroid candidate
raft::copy(minClusterDistance.data_handle(),
minDistBuf.data_handle() + bestCandidateIdx * n_samples,
n_samples,
stream);
raft::copy(handle,
raft::make_device_vector_view(minClusterDistance.data_handle(), n_samples),
raft::make_device_vector_view(
minDistBuf.data_handle() + bestCandidateIdx * n_samples, n_samples));

raft::copy(centroidsRawData.data_handle() + n_clusters_picked * n_features,
centroidCandidates.data_handle() + bestCandidateIdx * n_features,
n_features,
stream);
raft::copy(handle,
raft::make_device_vector_view(
centroidsRawData.data_handle() + n_clusters_picked * n_features, n_features),
raft::make_device_vector_view(
centroidCandidates.data_handle() + bestCandidateIdx * n_features, n_features));

++n_clusters_picked;
/// <<< End of Step-4 >>>
Expand Down Expand Up @@ -383,8 +393,7 @@ void kmeans_fit_main(raft::resources const& handle,

if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
}

RAFT_LOG_DEBUG(
Expand Down Expand Up @@ -448,10 +457,11 @@ void kmeans_fit_main(raft::resources const& handle,
newCentroids.data_handle());

DataT sqrdNormError = 0;
raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream);
raft::copy(handle, raft::make_host_scalar_view(&sqrdNormError), sqrdNorm.view());

raft::copy(
centroidsRawData.data_handle(), newCentroids.data_handle(), newCentroids.size(), stream);
raft::copy(handle,
raft::make_device_vector_view(centroidsRawData.data_handle(), newCentroids.size()),
raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size()));

bool done = false;
if (params.inertia_check) {
Expand Down Expand Up @@ -501,12 +511,10 @@ void kmeans_fit_main(raft::resources const& handle,
params.batch_centroids,
workspace);

// TODO: add different templates for InType of binaryOp to avoid thrust transform
thrust::transform(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
weight.data_handle(),
minClusterAndDistance.data_handle(),
raft::linalg::map(handle,
raft::make_const_mdspan(minClusterAndDistance.view()),
weight,
Comment thread
achirkin marked this conversation as resolved.
Outdated
minClusterAndDistance.view(),
[=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
raft::KeyValuePair<IndexT, DataT> res;
res.value = kvp.value * wt;
Expand Down Expand Up @@ -586,13 +594,16 @@ void initScalableKMeansPlusPlus(raft::resources const& handle,
// device buffer to flag the sample that is chosen as initial centroid
auto isSampleCentroid = raft::make_device_vector<uint8_t, IndexT>(handle, n_samples);

raft::copy(
isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream);
raft::copy(handle,
raft::make_device_vector_view(isSampleCentroid.data_handle(), isSampleCentroid.size()),
raft::make_host_vector_view(h_isSampleCentroid.data(), isSampleCentroid.size()));

rmm::device_uvector<DataT> centroidsBuf(initialCentroid.size(), stream);

// reset buffer to store the chosen centroid
raft::copy(centroidsBuf.data(), initialCentroid.data_handle(), initialCentroid.size(), stream);
raft::copy(handle,
raft::make_device_vector_view(centroidsBuf.data(), initialCentroid.size()),
raft::make_device_vector_view(initialCentroid.data_handle(), initialCentroid.size()));

auto potentialCentroids = raft::make_device_matrix_view<DataT, IndexT>(
centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1));
Expand All @@ -606,8 +617,7 @@ void initScalableKMeansPlusPlus(raft::resources const& handle,
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
}

auto minClusterDistanceVec = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
Expand Down Expand Up @@ -700,8 +710,10 @@ void initScalableKMeansPlusPlus(raft::resources const& handle,
/// <<< Step-5 >>>: C = C U C'
// append the data in Cp to the buffer holding the potentialCentroids
centroidsBuf.resize(centroidsBuf.size() + Cp.size(), stream);
raft::copy(
centroidsBuf.data() + centroidsBuf.size() - Cp.size(), Cp.data_handle(), Cp.size(), stream);
raft::copy(handle,
raft::make_device_vector_view(centroidsBuf.data() + centroidsBuf.size() - Cp.size(),
Cp.size()),
raft::make_device_vector_view(Cp.data_handle(), Cp.size()));

IndexT tot_centroids = potentialCentroids.extent(0) + Cp.extent(0);
potentialCentroids =
Expand Down Expand Up @@ -760,16 +772,17 @@ void initScalableKMeansPlusPlus(raft::resources const& handle,
initRandom<DataT, IndexT>(handle, rand_params, X, centroidsRawData);

// copy centroids generated during kmeans|| iteration to the buffer
raft::copy(centroidsRawData.data_handle() + n_random_clusters * n_features,
potentialCentroids.data_handle(),
potentialCentroids.size(),
stream);
raft::copy(
handle,
raft::make_device_vector_view(centroidsRawData.data_handle() + n_random_clusters * n_features,
potentialCentroids.size()),
raft::make_device_vector_view(potentialCentroids.data_handle(), potentialCentroids.size()));
} else {
// found the required n_clusters
raft::copy(centroidsRawData.data_handle(),
potentialCentroids.data_handle(),
potentialCentroids.size(),
stream);
raft::copy(
handle,
raft::make_device_vector_view(centroidsRawData.data_handle(), potentialCentroids.size()),
raft::make_device_vector_view(potentialCentroids.data_handle(), potentialCentroids.size()));
}
}

Expand Down Expand Up @@ -850,12 +863,9 @@ void kmeans_fit(raft::resources const& handle,
rmm::device_uvector<char> workspace(0, stream);
auto weight = raft::make_device_vector<DataT>(handle, n_samples);
if (sample_weight.has_value())
raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream);
raft::copy(handle, weight.view(), sample_weight.value());
else
thrust::fill(raft::resource::get_thrust_policy(handle),
weight.data_handle(),
weight.data_handle() + weight.size(),
1);
raft::matrix::fill(handle, weight.view(), DataT(1));

// check if weights sum up to n_samples
checkWeight<DataT>(handle, weight.view(), workspace);
Expand Down Expand Up @@ -910,7 +920,9 @@ void kmeans_fit(raft::resources const& handle,
seed_iter + 1,
n_init);
raft::copy(
centroidsRawData.data_handle(), centroids.data_handle(), n_clusters * n_features, stream);
handle,
raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features),
raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features));
} else {
THROW("unknown initialization method to select initial centers");
}
Expand All @@ -928,7 +940,9 @@ void kmeans_fit(raft::resources const& handle,
inertia[0] = iter_inertia;
n_iter[0] = n_current_iter;
raft::copy(
centroids.data_handle(), centroidsRawData.data_handle(), n_clusters * n_features, stream);
handle,
raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features),
raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features));
}
RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d",
seed_iter + 1,
Expand Down Expand Up @@ -998,12 +1012,9 @@ void kmeans_predict(raft::resources const& handle,
rmm::device_uvector<char> workspace(0, stream);
auto weight = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
if (sample_weight.has_value())
raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream);
raft::copy(handle, weight.view(), sample_weight.value());
else
thrust::fill(raft::resource::get_thrust_policy(handle),
weight.data_handle(),
weight.data_handle() + weight.size(),
1);
raft::matrix::fill(handle, weight.view(), DataT(1));

// check if weights sum up to n_samples
if (normalize_weight) checkWeight(handle, weight.view(), workspace);
Expand All @@ -1016,8 +1027,7 @@ void kmeans_predict(raft::resources const& handle,
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
}

// computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i]
Expand All @@ -1041,12 +1051,10 @@ void kmeans_predict(raft::resources const& handle,

// calculate cluster cost phi_x(C)
rmm::device_scalar<DataT> clusterCostD(stream);
// TODO: add different templates for InType of binaryOp to avoid thrust transform
thrust::transform(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
weight.data_handle(),
minClusterAndDistance.data_handle(),
raft::linalg::map(handle,
raft::make_const_mdspan(minClusterAndDistance.view()),
raft::make_const_mdspan(weight.view()),
minClusterAndDistance.view(),
[=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
raft::KeyValuePair<IndexT, DataT> res;
res.value = kvp.value * wt;
Expand All @@ -1062,11 +1070,8 @@ void kmeans_predict(raft::resources const& handle,
raft::value_op{},
raft::add_op{});

thrust::transform(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
labels.data_handle(),
raft::key_op{});
raft::linalg::map(
handle, raft::make_const_mdspan(minClusterAndDistance.view()), labels, raft::key_op{});

inertia[0] = clusterCostD.value(stream);
}
Expand Down
Loading