Skip to content

Commit e736d05

Browse files
authored
Fix UMAP outlier issue by checking for outliers and shuffling (#7131)
Closing #6454 Main difference between our simplicial set embedding and CPU UMAP was in negative sampling. We should use updated values (value after adding gradients) in the negative sampling stage. Dispatched to two kernels (and three usages) based on `n_components`. Fixed like below. - `optimize_batch_kernel_reg` (`n_components=2`): update the `current_reg` register value (used later in the negative sampling stage) along with `grads` - `optimize_batch_kernel` (with shared memory): distinguish `current_buffer` (which used to JUST hold the gradient) from the `grad_buffer`. Now `current_buffer` and `grad_buffer` correspond to the `current_reg` and `grads` registers in the register-approach kernel. - `optimize_batch_kernel` (without shared memory): untouched because the grads are applied directly to global memory. This updated value in global memory is read directly for negative sampling later on. ## Visualizations 2D 50K samples randomly selected for plotting. From the left - CPU KNN + CPU UMAP - GPU KNN + CPU UMAP - GPU KNN + GPU UMAP Before fix - GPU KNN + GPU UMAP After fix in this PR Using dataset 639K x 384 <img width="2400" height="600" alt="unique_embeddings_Beauty_comparison" src="https://github.com/user-attachments/assets/2b687c82-4a2d-4288-bcaa-d95d54a1b8ae" /> Using dataset 1.8M x 384 <img width="2400" height="600" alt="unique_embeddings_Appliances_comparison" src="https://github.com/user-attachments/assets/66e94360-6a55-4d37-8851-69c00e485685" /> ## Visualizations 3D 50K samples randomly selected for plotting. Plotting the same dataset with `n_components=3` (which uses the second kernel). 
From the left - GPU KNN + CPU UMAP - GPU KNN + GPU UMAP Before fix - GPU KNN + GPU UMAP After fix in this PR Using dataset 639K x 384 (was already doing pretty well without outliers, still doing well) <img width="1905" height="666" alt="Screenshot 2025-08-25 at 1 16 37 PM" src="https://github.com/user-attachments/assets/edbfec64-ae9a-45f6-84b4-cc7e3c431884" /> Using dataset 1.8M x 384, which had outliers before the fix. <img width="1768" height="716" alt="Screenshot 2025-08-25 at 1 22 41 PM" src="https://github.com/user-attachments/assets/cfcffc8c-0ee3-4ad8-81f3-692483fec70e" /> Authors: - Jinsol Park (https://github.com/jinsolp) - Dante Gama Dessavre (https://github.com/dantegd) - Simon Adorf (https://github.com/csadorf) Approvers: - Victor Lafargue (https://github.com/viclafargue) - Divye Gala (https://github.com/divyegala) - Simon Adorf (https://github.com/csadorf) URL: #7131
1 parent e5adc43 commit e736d05

3 files changed

Lines changed: 185 additions & 21 deletions

File tree

cpp/src/umap/simpl_set_embed/algo.cuh

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,24 @@
2121
#include <cuml/common/logger.hpp>
2222
#include <cuml/manifold/umapparams.h>
2323

24+
#include <raft/linalg/init.cuh>
2425
#include <raft/linalg/unary_op.cuh>
2526
#include <raft/sparse/coo.hpp>
2627
#include <raft/sparse/op/filter.cuh>
2728
#include <raft/util/cudart_utils.hpp>
2829

30+
#include <rmm/device_scalar.hpp>
2931
#include <rmm/device_uvector.hpp>
3032
#include <rmm/exec_policy.hpp>
3133

34+
#include <thrust/device_ptr.h>
3235
#include <thrust/iterator/constant_iterator.h>
3336
#include <thrust/iterator/discard_iterator.h>
37+
#include <thrust/iterator/zip_iterator.h>
3438
#include <thrust/reduce.h>
39+
#include <thrust/shuffle.h>
3540
#include <thrust/system/cuda/execution_policy.h>
41+
#include <thrust/tuple.h>
3642

3743
#include <curand.h>
3844
#include <math.h>
@@ -185,6 +191,47 @@ T create_gradient_rounding_factor(
185191
return create_rounding_factor(max_abs, n_edges);
186192
}
187193

194+
template <typename nnz_t>
195+
CUML_KERNEL void compute_degrees_kernel(const int* rows, nnz_t nnz, int* degrees)
196+
{
197+
nnz_t i = blockIdx.x * blockDim.x + threadIdx.x;
198+
if (i < nnz) {
199+
int row = rows[i];
200+
atomicAdd(&degrees[row], 1);
201+
}
202+
}
203+
204+
CUML_KERNEL void check_threshold_kernel(const int* degrees,
205+
int n_vertices,
206+
int threshold,
207+
bool* flag)
208+
{
209+
int i = blockIdx.x * blockDim.x + threadIdx.x;
210+
if (i < n_vertices) {
211+
if (degrees[i] > threshold) { *flag = true; }
212+
}
213+
}
214+
215+
template <typename nnz_t, int TPB_X>
216+
bool check_outliers(const int* rows, int m, nnz_t nnz, int threshold, cudaStream_t stream)
217+
{
218+
rmm::device_uvector<int> graph_degree_head(m, stream);
219+
raft::linalg::zero(graph_degree_head.data(), m, stream);
220+
221+
dim3 grid_nnz(raft::ceildiv(nnz, static_cast<nnz_t>(TPB_X)), 1, 1);
222+
dim3 blk(TPB_X, 1, 1);
223+
compute_degrees_kernel<<<grid_nnz, blk, 0, stream>>>(rows, nnz, graph_degree_head.data());
224+
225+
rmm::device_scalar<bool> has_outlier_d(0, stream); // initialize to 0
226+
227+
dim3 grid_head_n(raft::ceildiv(static_cast<nnz_t>(m), static_cast<nnz_t>(TPB_X)), 1, 1);
228+
check_threshold_kernel<<<grid_head_n, blk, 0, stream>>>(
229+
graph_degree_head.data(), m, threshold, has_outlier_d.data());
230+
cudaStreamSynchronize(stream);
231+
bool has_outlier_h = has_outlier_d.value(stream);
232+
return has_outlier_h;
233+
}
234+
188235
/**
189236
* Runs gradient descent using sampling weights defined on
190237
* both the attraction and repulsion vectors.
@@ -199,8 +246,8 @@ void optimize_layout(T* head_embedding,
199246
int head_n,
200247
T* tail_embedding,
201248
int tail_n,
202-
const int* head,
203-
const int* tail,
249+
int* head,
250+
int* tail,
204251
nnz_t nnz,
205252
T* epochs_per_sample,
206253
float gamma,
@@ -213,6 +260,39 @@ void optimize_layout(T* head_embedding,
213260
T alpha = params->initial_alpha;
214261

215262
auto stream_view = rmm::cuda_stream_view(stream);
263+
264+
T rounding = create_gradient_rounding_factor<T, nnz_t>(head, nnz, head_n, alpha, stream_view);
265+
266+
auto min_n = min(head_n, tail_n);
267+
int threshold_for_outlier = 1024; // this is a heuristic value.
268+
// for smaller datasets, could be a dense point even with a smaller graph degree
269+
if (min_n <= 100000) {
270+
threshold_for_outlier = 256;
271+
} else if (min_n <= 1000000) {
272+
threshold_for_outlier = 512;
273+
}
274+
275+
bool has_outlier = check_outliers<nnz_t, TPB_X>(head, head_n, nnz, threshold_for_outlier, stream);
276+
if (move_other && !has_outlier) {
277+
has_outlier = check_outliers<nnz_t, TPB_X>(tail, tail_n, nnz, threshold_for_outlier, stream);
278+
}
279+
280+
if (has_outlier) {
281+
// Shuffling is necessary when outliers may be present (i.e., dense points that undergo many
282+
// updates). It is critical to avoid having too many threads update the same embedding vector
283+
// simultaneously, as this can affect correctness. By shuffling, potential outlier points are
284+
// distributed across threads, rather than being processed by consecutive threads that are
285+
// scheduled together. This approach relies on the GPU's inability to physically schedule all
286+
// nnz edges at once.
287+
auto first =
288+
thrust::make_zip_iterator(thrust::make_tuple(thrust::device_pointer_cast(head),
289+
thrust::device_pointer_cast(tail),
290+
thrust::device_pointer_cast(epochs_per_sample)));
291+
292+
thrust::default_random_engine rng(params->random_state);
293+
thrust::shuffle(first, first + nnz, rng);
294+
}
295+
216296
rmm::device_uvector<T> epoch_of_next_negative_sample(nnz, stream);
217297
T nsr_inv = T(1.0) / params->negative_sample_rate;
218298
raft::linalg::unaryOp<T>(
@@ -250,8 +330,6 @@ void optimize_layout(T* head_embedding,
250330
dim3 blk(TPB_X, 1, 1);
251331
uint64_t seed = params->random_state;
252332

253-
T rounding = create_gradient_rounding_factor<T, nnz_t>(head, nnz, head_n, alpha, stream_view);
254-
255333
for (int n = 0; n < n_epochs; n++) {
256334
call_optimize_batch_kernel<T, nnz_t, TPB_X>(head_embedding,
257335
d_head_buffer,

cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
156156
for (int d = 0; d < n_components; d++) {
157157
auto diff = current_reg[d] - other_reg[d];
158158
auto grad_d = clip<T>(attractive_grad_coeff * diff, T(-4.0), T(4.0));
159-
grads[d] = grad_d * alpha;
159+
current_reg[d] += grad_d * alpha;
160+
grads[d] = grad_d * alpha;
160161
}
161162
// storing gradients for negative samples back to global memory
162163
if (move_other) {
@@ -200,6 +201,7 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
200201
grad_d = clip<T>(repulsive_grad_coeff * diff, T(-4.0), T(4.0));
201202
else
202203
grad_d = T(4.0);
204+
current_reg[d] += grad_d * alpha;
203205
grads[d] += grad_d * alpha;
204206
}
205207
}
@@ -252,8 +254,17 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
252254
T* cur_write = head_buffer + (j * n_components);
253255
T* oth_write = tail_buffer + (k * n_components);
254256

257+
// for reducing access to global memory. load values from global memory, and accumulate grads onto
258+
// this shared memory position instead of reading from global memory every time.
255259
T* current_buffer{nullptr};
256-
if (use_shared_mem) { current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x; }
260+
// for keeping track of grads, final write to global memory
261+
T* grads_buffer{nullptr};
262+
if constexpr (use_shared_mem) {
263+
// n_components for thread0, then the next n_components for thread1 ...
264+
current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x * n_components;
265+
// TPB_X for first component, then another TPB_X for the next component for better coalescing...
266+
grads_buffer = (T*)embedding_shared_mem_updates + TPB_X * n_components + threadIdx.x;
267+
}
257268
auto dist_squared = rdist<T>(current, other, n_components);
258269
// Attractive force between the two vertices, since they
259270
// are connected by an edge in the 1-skeleton.
@@ -267,10 +278,13 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
267278
* performing unsupervised training).
268279
*/
269280
for (int d = 0; d < n_components; d++) {
270-
auto grad_d = clip<T>(attractive_grad_coeff * (current[d] - other[d]), T(-4.0), T(4.0));
281+
T current_val = current[d];
282+
if constexpr (use_shared_mem) { current_buffer[d] = current_val; }
283+
auto grad_d = clip<T>(attractive_grad_coeff * (current_val - other[d]), T(-4.0), T(4.0));
271284
grad_d *= alpha;
272-
if (use_shared_mem) {
273-
current_buffer[d * TPB_X] = grad_d;
285+
if constexpr (use_shared_mem) {
286+
current_buffer[d] += grad_d;
287+
grads_buffer[d * TPB_X] = grad_d;
274288
} else {
275289
raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grad_d));
276290
if (move_other) { // happens only during unsupervised training
@@ -282,7 +296,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
282296
if (use_shared_mem && move_other) {
283297
__syncthreads();
284298
for (int d = 0; d < n_components; d++) {
285-
auto grad = current_buffer[d * TPB_X];
299+
auto grad = grads_buffer[d * TPB_X];
286300
raft::myAtomicAdd<T>((T*)oth_write + d, truncate_gradient(rounding, -grad));
287301
}
288302
}
@@ -299,7 +313,11 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
299313
gen.next(r);
300314
nnz_t t = r % tail_n;
301315
T const* negative_sample = tail_embedding + (t * n_components);
302-
dist_squared = rdist<T>(current, negative_sample, n_components);
316+
if constexpr (use_shared_mem) {
317+
dist_squared = rdist<T>(current_buffer, negative_sample, n_components);
318+
} else {
319+
dist_squared = rdist<T>(current, negative_sample, n_components);
320+
}
303321
// repulsive force between two vertices
304322
auto repulsive_grad_coeff = T(0.0);
305323
if (dist_squared > T(0.0)) {
@@ -313,25 +331,31 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
313331
*/
314332
for (int d = 0; d < n_components; d++) {
315333
auto grad_d = T(0.0);
316-
if (repulsive_grad_coeff > T(0.0))
317-
grad_d = clip<T>(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0));
318-
else
334+
if (repulsive_grad_coeff > T(0.0)) {
335+
if constexpr (use_shared_mem) {
336+
grad_d = clip<T>(
337+
repulsive_grad_coeff * (current_buffer[d] - negative_sample[d]), T(-4.0), T(4.0));
338+
} else {
339+
grad_d =
340+
clip<T>(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0));
341+
}
342+
} else
319343
grad_d = T(4.0);
320344
grad_d *= alpha;
321-
if (use_shared_mem) {
322-
current_buffer[d * TPB_X] += grad_d;
345+
if constexpr (use_shared_mem) {
346+
current_buffer[d] += grad_d;
347+
grads_buffer[d * TPB_X] += grad_d;
323348
} else {
324349
raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grad_d));
325350
}
326351
}
327352
}
328353

329354
// storing gradients for positive samples back to global memory
330-
if (use_shared_mem) {
355+
if constexpr (use_shared_mem) {
331356
__syncthreads();
332357
for (int d = 0; d < n_components; d++) {
333-
raft::myAtomicAdd<T>((T*)cur_write + d,
334-
truncate_gradient(rounding, current_buffer[d * TPB_X]));
358+
raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grads_buffer[d * TPB_X]));
335359
}
336360
}
337361
epoch_of_next_negative_sample[row] =
@@ -373,7 +397,7 @@ void call_optimize_batch_kernel(T const* head_embedding,
373397
cudaStream_t& stream,
374398
T rounding)
375399
{
376-
std::size_t requiredSize = TPB_X * params->n_components;
400+
std::size_t requiredSize = TPB_X * params->n_components * 2;
377401
requiredSize *= sizeof(T);
378402
bool use_shared_mem = requiredSize < static_cast<std::size_t>(raft::getSharedMemPerBlock());
379403
T nsr_inv = T(1.0) / params->negative_sample_rate;

python/cuml/tests/test_umap.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pylibraft.common import DeviceResourcesSNMG
3030
from sklearn import datasets
3131
from sklearn.cluster import KMeans
32-
from sklearn.datasets import make_blobs
32+
from sklearn.datasets import make_blobs, make_moons
3333
from sklearn.manifold import trustworthiness
3434
from sklearn.metrics import adjusted_rand_score
3535
from sklearn.neighbors import NearestNeighbors
@@ -924,3 +924,65 @@ def test_umap_small_fit_large_transform():
924924

925925
trust = trustworthiness(infer, embeddings, n_neighbors=10)
926926
assert trust >= 0.9
927+
928+
929+
@pytest.mark.parametrize("n_neighbors", [5, 15])
930+
@pytest.mark.parametrize("n_components", [2, 5])
931+
def test_umap_outliers(n_neighbors, n_components):
932+
all_neighbors = pytest.importorskip("cuvs.neighbors.all_neighbors")
933+
nn_descent = pytest.importorskip("cuvs.neighbors.nn_descent")
934+
935+
k = n_neighbors
936+
n_rows = 50_000
937+
938+
# This dataset was specifically chosen because UMAP produces outliers
939+
# on this dataset before the outlier fix.
940+
data, _ = make_moons(n_samples=n_rows, noise=0.0, random_state=42)
941+
data = data.astype(np.float32)
942+
943+
# precompute knn for faster testing with CPU UMAP
944+
nn_descent_params = nn_descent.IndexParams(
945+
metric="euclidean",
946+
graph_degree=k,
947+
intermediate_graph_degree=k * 2,
948+
)
949+
params = all_neighbors.AllNeighborsParams(
950+
algo="nn_descent",
951+
metric="euclidean",
952+
nn_descent_params=nn_descent_params,
953+
)
954+
indices, distances = all_neighbors.build(
955+
data,
956+
k,
957+
params,
958+
distances=cp.empty((n_rows, k), dtype=cp.float32),
959+
)
960+
indices = cp.asnumpy(indices)
961+
distances = cp.asnumpy(distances)
962+
963+
gpu_umap = cuUMAP(
964+
precomputed_knn=(indices, distances),
965+
build_algo="nn_descent",
966+
init="spectral",
967+
n_neighbors=n_neighbors,
968+
n_components=n_components,
969+
)
970+
gpu_umap_embeddings = gpu_umap.fit_transform(data)
971+
972+
cpu_umap = umap.UMAP(
973+
precomputed_knn=(indices, distances),
974+
init="spectral",
975+
n_neighbors=n_neighbors,
976+
n_components=n_components,
977+
)
978+
cpu_umap_embeddings = cpu_umap.fit_transform(data)
979+
980+
# test to see if there are values in the final embedding that are too out of range
981+
# compared to the cpu umap output.
982+
lower_bound = 3 * cpu_umap_embeddings.min()
983+
upper_bound = 3 * cpu_umap_embeddings.max()
984+
985+
assert np.all(
986+
(gpu_umap_embeddings >= lower_bound)
987+
& (gpu_umap_embeddings <= upper_bound)
988+
)

0 commit comments

Comments
 (0)