Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 75 additions & 4 deletions cpp/src/umap/simpl_set_embed/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/device_ptr.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/reduce.h>
#include <thrust/shuffle.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/tuple.h>

#include <curand.h>
#include <math.h>
Expand Down Expand Up @@ -185,6 +189,50 @@ T create_gradient_rounding_factor(
return create_rounding_factor(max_abs, n_edges);
}

/**
 * Accumulates a per-vertex degree histogram from COO row indices.
 *
 * One thread per nonzero: thread i reads rows[i] and atomically increments
 * degrees[rows[i]]. `degrees` must be zero-initialized before launch.
 *
 * @param rows    COO row indices, length nnz
 * @param nnz     number of nonzeros (nnz_t may be 64-bit)
 * @param degrees output degree counts, one per vertex, pre-zeroed
 */
template <typename nnz_t>
CUML_KERNEL void compute_degrees_kernel(const int* rows, nnz_t nnz, int* degrees)
{
  // Widen before multiplying: blockIdx.x * blockDim.x is a 32-bit unsigned
  // product and would wrap for grids covering > 2^32 edges even when nnz_t
  // itself is 64-bit.
  nnz_t i = static_cast<nnz_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < nnz) {
    int row = rows[i];
    atomicAdd(&degrees[row], 1);
  }
}

/**
 * Raises *flag to 1 when any vertex degree exceeds `threshold`.
 *
 * One thread per vertex; `flag` must be zero-initialized before launch.
 * The write is one-way (0 -> 1), so the result is deterministic regardless
 * of which thread wins the exchange.
 *
 * @param degrees    per-vertex degree counts, length n_vertices
 * @param n_vertices number of vertices to scan
 * @param threshold  degree above which a vertex counts as an outlier
 * @param flag       single-int output, pre-zeroed; set to 1 on detection
 */
CUML_KERNEL void check_threshold_kernel(const int* degrees,
                                        int n_vertices,
                                        int threshold,
                                        int* flag)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_vertices) {
    if (degrees[i] > threshold) { atomicExch(flag, 1); }
  }
}

/**
 * Returns true when any vertex appears in more than `threshold` COO entries.
 *
 * Builds a per-vertex degree histogram from the row indices, scans it for a
 * degree above the threshold, then blocks the host on `stream` to read the
 * result back.
 *
 * @param rows      COO row indices, length nnz
 * @param m         number of vertices
 * @param nnz       number of nonzeros
 * @param threshold degree above which a vertex counts as an outlier
 * @param stream    CUDA stream all work is ordered on
 */
template <typename nnz_t, int TPB_X>
bool check_outliers(const int* rows, int m, nnz_t nnz, int threshold, cudaStream_t stream)
{
  rmm::device_uvector<int> graph_degree_head(m, stream);
  // Async memset ordered on `stream`: a plain cudaMemset runs on the legacy
  // default stream and is not guaranteed to precede the kernels below when
  // `stream` is non-blocking.
  RAFT_CUDA_TRY(cudaMemsetAsync(graph_degree_head.data(), 0, m * sizeof(int), stream));

  dim3 blk(TPB_X, 1, 1);
  dim3 grid_nnz(raft::ceildiv(nnz, static_cast<nnz_t>(TPB_X)), 1, 1);
  compute_degrees_kernel<<<grid_nnz, blk, 0, stream>>>(rows, nnz, graph_degree_head.data());
  RAFT_CUDA_TRY(cudaPeekAtLastError());

  rmm::device_uvector<int> has_outlier_d(1, stream);
  RAFT_CUDA_TRY(cudaMemsetAsync(has_outlier_d.data(), 0, sizeof(int), stream));

  dim3 grid_head_n(raft::ceildiv(static_cast<nnz_t>(m), static_cast<nnz_t>(TPB_X)), 1, 1);
  check_threshold_kernel<<<grid_head_n, blk, 0, stream>>>(
    graph_degree_head.data(), m, threshold, has_outlier_d.data());
  RAFT_CUDA_TRY(cudaPeekAtLastError());

  int has_outlier_h = 0;
  raft::copy(&has_outlier_h, has_outlier_d.data(), 1, stream);
  // raft::copy is asynchronous; without this sync the host reads the flag
  // before the device-to-host copy has completed.
  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
  return has_outlier_h != 0;
}

/**
* Runs gradient descent using sampling weights defined on
* both the attraction and repulsion vectors.
Expand All @@ -199,8 +247,8 @@ void optimize_layout(T* head_embedding,
int head_n,
T* tail_embedding,
int tail_n,
const int* head,
const int* tail,
int* head,
int* tail,
nnz_t nnz,
T* epochs_per_sample,
float gamma,
Expand All @@ -213,6 +261,31 @@ void optimize_layout(T* head_embedding,
T alpha = params->initial_alpha;

auto stream_view = rmm::cuda_stream_view(stream);

T rounding = create_gradient_rounding_factor<T, nnz_t>(head, nnz, head_n, alpha, stream_view);

int threshold_for_outlier = 1024; // this is a heuristic value.
bool has_outlier = check_outliers<nnz_t, TPB_X>(head, head_n, nnz, threshold_for_outlier, stream);
if (move_other && !has_outlier) {
has_outlier = check_outliers<nnz_t, TPB_X>(tail, tail_n, nnz, threshold_for_outlier, stream);
}

if (has_outlier) {
// Shuffling is necessary when outliers may be present (i.e., dense points that undergo many
// updates). It is critical to avoid having too many threads update the same embedding vector
// simultaneously, as this can affect correctness. By shuffling, potential outlier points are
// distributed across threads, rather than being processed by consecutive threads that are
// scheduled together. This approach relies on the GPU's inability to physically schedule all
// nnz edges at once.
auto first =
thrust::make_zip_iterator(thrust::make_tuple(thrust::device_pointer_cast(head),
thrust::device_pointer_cast(tail),
thrust::device_pointer_cast(epochs_per_sample)));

thrust::default_random_engine rng(params->random_state);
thrust::shuffle(first, first + nnz, rng);
Comment thread
jinsolp marked this conversation as resolved.
}

rmm::device_uvector<T> epoch_of_next_negative_sample(nnz, stream);
T nsr_inv = T(1.0) / params->negative_sample_rate;
raft::linalg::unaryOp<T>(
Expand Down Expand Up @@ -250,8 +323,6 @@ void optimize_layout(T* head_embedding,
dim3 blk(TPB_X, 1, 1);
uint64_t seed = params->random_state;

T rounding = create_gradient_rounding_factor<T, nnz_t>(head, nnz, head_n, alpha, stream_view);

for (int n = 0; n < n_epochs; n++) {
call_optimize_batch_kernel<T, nnz_t, TPB_X>(head_embedding,
d_head_buffer,
Expand Down
56 changes: 40 additions & 16 deletions cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
for (int d = 0; d < n_components; d++) {
auto diff = current_reg[d] - other_reg[d];
auto grad_d = clip<T>(attractive_grad_coeff * diff, T(-4.0), T(4.0));
grads[d] = grad_d * alpha;
current_reg[d] += grad_d * alpha;
grads[d] = grad_d * alpha;
}
// storing gradients for negative samples back to global memory
if (move_other) {
Expand Down Expand Up @@ -200,6 +201,7 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
grad_d = clip<T>(repulsive_grad_coeff * diff, T(-4.0), T(4.0));
else
grad_d = T(4.0);
current_reg[d] += grad_d * alpha;
grads[d] += grad_d * alpha;
}
}
Expand Down Expand Up @@ -252,8 +254,17 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
T* cur_write = head_buffer + (j * n_components);
T* oth_write = tail_buffer + (k * n_components);

// for reducing access to global memory. load values from global memory, and accumulate grads onto
// this shared memory position instead of reading from global memory every time.
T* current_buffer{nullptr};
if (use_shared_mem) { current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x; }
// for keeping track of grads, final write to global memory
T* grads_buffer{nullptr};
if constexpr (use_shared_mem) {
Comment thread
jinsolp marked this conversation as resolved.
// n_components for thread0, then the next n_components for thread1 ...
current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x * n_components;
// TPB_X for first component, then another TPB_X for the next component for better coalescing...
grads_buffer = (T*)embedding_shared_mem_updates + TPB_X * n_components + threadIdx.x;
Comment thread
jinsolp marked this conversation as resolved.
}
auto dist_squared = rdist<T>(current, other, n_components);
// Attractive force between the two vertices, since they
// are connected by an edge in the 1-skeleton.
Expand All @@ -267,10 +278,13 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
* performing unsupervised training).
*/
for (int d = 0; d < n_components; d++) {
auto grad_d = clip<T>(attractive_grad_coeff * (current[d] - other[d]), T(-4.0), T(4.0));
T current_val = current[d];
if constexpr (use_shared_mem) { current_buffer[d] = current_val; }
auto grad_d = clip<T>(attractive_grad_coeff * (current_val - other[d]), T(-4.0), T(4.0));
grad_d *= alpha;
if (use_shared_mem) {
current_buffer[d * TPB_X] = grad_d;
if constexpr (use_shared_mem) {
current_buffer[d] += grad_d;
grads_buffer[d * TPB_X] = grad_d;
} else {
raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grad_d));
if (move_other) { // happens only during unsupervised training
Expand All @@ -282,7 +296,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
if (use_shared_mem && move_other) {
__syncthreads();
for (int d = 0; d < n_components; d++) {
auto grad = current_buffer[d * TPB_X];
auto grad = grads_buffer[d * TPB_X];
raft::myAtomicAdd<T>((T*)oth_write + d, truncate_gradient(rounding, -grad));
}
}
Expand All @@ -299,7 +313,11 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
gen.next(r);
nnz_t t = r % tail_n;
T const* negative_sample = tail_embedding + (t * n_components);
dist_squared = rdist<T>(current, negative_sample, n_components);
if constexpr (use_shared_mem) {
dist_squared = rdist<T>(current_buffer, negative_sample, n_components);
} else {
dist_squared = rdist<T>(current, negative_sample, n_components);
}
// repulsive force between two vertices
auto repulsive_grad_coeff = T(0.0);
if (dist_squared > T(0.0)) {
Expand All @@ -313,25 +331,31 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
*/
for (int d = 0; d < n_components; d++) {
auto grad_d = T(0.0);
if (repulsive_grad_coeff > T(0.0))
grad_d = clip<T>(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0));
else
if (repulsive_grad_coeff > T(0.0)) {
if constexpr (use_shared_mem) {
grad_d = clip<T>(
repulsive_grad_coeff * (current_buffer[d] - negative_sample[d]), T(-4.0), T(4.0));
} else {
grad_d =
clip<T>(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0));
}
} else
grad_d = T(4.0);
grad_d *= alpha;
if (use_shared_mem) {
current_buffer[d * TPB_X] += grad_d;
if constexpr (use_shared_mem) {
current_buffer[d] += grad_d;
grads_buffer[d * TPB_X] += grad_d;
} else {
raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grad_d));
}
}
}

// storing gradients for positive samples back to global memory
if (use_shared_mem) {
if constexpr (use_shared_mem) {
__syncthreads();
for (int d = 0; d < n_components; d++) {
raft::myAtomicAdd<T>((T*)cur_write + d,
truncate_gradient(rounding, current_buffer[d * TPB_X]));
raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grads_buffer[d * TPB_X]));
Comment on lines -334 to +358
Copy link
Copy Markdown
Contributor

@viclafargue viclafargue Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Importantly, when random_state is set, current != cur_write and other != oth_write as updates accumulate in separate buffer to allow high precision deterministic accumulation of updates. It looks like we may still have outliers in this case? But, I guess that is acceptable for now.

}
}
epoch_of_next_negative_sample[row] =
Expand Down Expand Up @@ -373,7 +397,7 @@ void call_optimize_batch_kernel(T const* head_embedding,
cudaStream_t& stream,
T rounding)
{
std::size_t requiredSize = TPB_X * params->n_components;
std::size_t requiredSize = TPB_X * params->n_components * 2;
Comment thread
jinsolp marked this conversation as resolved.
requiredSize *= sizeof(T);
bool use_shared_mem = requiredSize < static_cast<std::size_t>(raft::getSharedMemPerBlock());
T nsr_inv = T(1.0) / params->negative_sample_rate;
Expand Down
Loading