Skip to content

Commit 7d144cf

Browse files
authored
Migrate trustworthiness and silhouette_score stats from RAFT (#313)
Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #313
1 parent 9e8ec39 commit 7d144cf

14 files changed

Lines changed: 1837 additions & 25 deletions

cpp/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,8 @@ add_library(
424424
src/selection/select_k_float_int64_t.cu
425425
src/selection/select_k_float_uint32_t.cu
426426
src/selection/select_k_half_uint32_t.cu
427+
src/stats/silhouette_score.cu
428+
src/stats/trustworthiness_score.cu
427429
)
428430

429431
target_compile_definitions(cuvs PRIVATE "CUVS_EXPLICIT_INSTANTIATE_ONLY")
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <cuvs/distance/distance.hpp>
19+
#include <raft/core/device_mdspan.hpp>
20+
#include <raft/core/resources.hpp>
21+
22+
namespace cuvs {
23+
namespace stats {
24+
25+
/**
26+
* @defgroup stats_silhouette_score Silhouette Score
27+
* @{
28+
*/
29+
/**
30+
* @brief Returns the average silhouette score for a given set of data and its clusterings
31+
* (single-precision overload)
32+
* @param[in] handle: raft handle for managing expensive resources
33+
* @param[in] X_in: input data matrix in row-major format (nRows x nCols)
34+
* @param[in] labels: device vector view holding the label of every data sample (length:
35+
* nRows)
36+
* @param[out] silhouette_score_per_sample: optional array populated with the silhouette score
37+
* for every sample (length: nRows)
38+
* @param[in] n_unique_labels: number of unique labels in the labels array
39+
* @param[in] metric: Distance metric to use. Defaults to DistanceType::L2Unexpanded
40+
* @return: The average silhouette score.
41+
*/
42+
float silhouette_score(
43+
raft::resources const& handle,
44+
raft::device_matrix_view<const float, int64_t, raft::row_major> X_in,
45+
raft::device_vector_view<const int, int64_t> labels,
46+
std::optional<raft::device_vector_view<float, int64_t>> silhouette_score_per_sample,
47+
int64_t n_unique_labels,
48+
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);
49+
50+
/**
51+
* @brief Batched variant that returns the average silhouette score for a given set of data and
52+
* its clusterings (single-precision overload)
53+
* @param[in] handle: raft handle for managing expensive resources
54+
* @param[in] X: input data matrix in row-major format (nRows x nCols)
55+
* @param[in] labels: device vector view holding the label of every data sample (length:
56+
* nRows)
57+
* @param[out] silhouette_score_per_sample: optional array populated with the silhouette score
58+
* for every sample (length: nRows)
59+
* @param[in] n_unique_labels: number of unique labels in the labels array
60+
* @param[in] batch_size: number of samples per batch
61+
* @param[in] metric: distance metric to use in the calculations. Defaults to
62+
* DistanceType::L2Unexpanded
63+
* @return: The average silhouette score.
64+
*/
65+
float silhouette_score_batched(
66+
raft::resources const& handle,
67+
raft::device_matrix_view<const float, int64_t, raft::row_major> X,
68+
raft::device_vector_view<const int, int64_t> labels,
69+
std::optional<raft::device_vector_view<float, int64_t>> silhouette_score_per_sample,
70+
int64_t n_unique_labels,
71+
int64_t batch_size,
72+
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);
73+
74+
/**
75+
* @brief Returns the average silhouette score for a given set of data and its clusterings
76+
* (double-precision overload)
77+
* @param[in] handle: raft handle for managing expensive resources
78+
* @param[in] X_in: input data matrix in row-major format (nRows x nCols)
79+
* @param[in] labels: device vector view holding the label of every data sample (length:
80+
* nRows)
81+
* @param[out] silhouette_score_per_sample: optional array populated with the silhouette score
82+
* for every sample (length: nRows)
83+
* @param[in] n_unique_labels: number of unique labels in the labels array
84+
* @param[in] metric: distance metric to use in the calculations. Defaults to
85+
* DistanceType::L2Unexpanded
86+
* @return: The average silhouette score.
87+
*/
88+
double silhouette_score(
89+
raft::resources const& handle,
90+
raft::device_matrix_view<const double, int64_t, raft::row_major> X_in,
91+
raft::device_vector_view<const int, int64_t> labels,
92+
std::optional<raft::device_vector_view<double, int64_t>> silhouette_score_per_sample,
93+
int64_t n_unique_labels,
94+
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);
95+
96+
/**
97+
* @brief Batched variant that returns the average silhouette score for a given set of data and
98+
* its clusterings (double-precision overload)
99+
* @param[in] handle: raft handle for managing expensive resources
100+
* @param[in] X: input data matrix in row-major format (nRows x nCols)
101+
* @param[in] labels: device vector view holding the label of every data sample (length:
102+
* nRows)
103+
* @param[out] silhouette_score_per_sample: optional array populated with the silhouette score
104+
* for every sample (length: nRows)
105+
* @param[in] n_unique_labels: number of unique labels in the labels array
106+
* @param[in] batch_size: number of samples per batch
107+
* @param[in] metric: distance metric to use in the calculations. Defaults to
108+
* DistanceType::L2Unexpanded
109+
* @return: The average silhouette score.
110+
*/
111+
double silhouette_score_batched(
112+
raft::resources const& handle,
113+
raft::device_matrix_view<const double, int64_t, raft::row_major> X,
114+
raft::device_vector_view<const int, int64_t> labels,
115+
std::optional<raft::device_vector_view<double, int64_t>> silhouette_score_per_sample,
116+
int64_t n_unique_labels,
117+
int64_t batch_size,
118+
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);
119+
/** @} */  // end group stats_silhouette_score
120+
} // namespace stats
121+
} // namespace cuvs
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
#include <cuvs/distance/distance.hpp>
19+
#include <raft/core/device_mdspan.hpp>
20+
#include <raft/core/resources.hpp>
21+
22+
namespace cuvs {
23+
namespace stats {
24+
/**
25+
* @defgroup stats_trustworthiness Trustworthiness
26+
* @{
27+
*/
28+
29+
/**
30+
* @brief Compute the trustworthiness score
31+
* @param[in] handle: the raft handle
32+
* @param[in] X: Data in original dimension (row-major device matrix)
33+
* @param[in] X_embedded: Data in target dimension (embedding, row-major device matrix)
34+
* @param[in] n_neighbors: Number of neighbors considered by the trustworthiness score
35+
* @param[in] metric: Distance metric to use. Defaults to DistanceType::L2SqrtUnexpanded
36+
* @param[in] batch_size: Batch size (defaults to 512)
37+
* @return Trustworthiness score
38+
* @note The constness of the data in X_embedded is currently cast away and the data is slightly
39+
* modified.
40+
*/
41+
double trustworthiness_score(
42+
raft::resources const& handle,
43+
raft::device_matrix_view<const float, int64_t, raft::row_major> X,
44+
raft::device_matrix_view<const float, int64_t, raft::row_major> X_embedded,
45+
int n_neighbors,
46+
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtUnexpanded,
47+
int batch_size = 512);
48+
49+
/** @} */ // end group stats_trustworthiness
50+
} // namespace stats
51+
} // namespace cuvs

0 commit comments

Comments
 (0)