Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 36 additions & 57 deletions cpp/src/umap/simpl_set_embed/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,48 +91,6 @@ void optimization_iteration_finalization(
seed += 1;
}

/**
 * @brief Apply the accumulated gradient updates to the embeddings and clear the
 * per-epoch buffers. Only used when the deterministic algorithm is enabled:
 * updates are staged into separate buffers during the epoch and folded into the
 * embeddings here, so the result does not depend on the order of concurrent
 * writes during the optimization kernel.
 *
 * @param head_embedding embedding being optimized (head_n x n_components)
 * @param head_buffer    staged updates for head_embedding; zeroed on exit
 * @param head_n         number of head vertices
 * @param tail_embedding reference embedding (tail_n x n_components)
 * @param tail_buffer    staged updates for tail_embedding; zeroed on exit
 * @param tail_n         number of tail vertices
 * @param params         UMAP parameters; params->deterministic must be true
 * @param move_other     when true, tail_embedding is updated as well
 * @param stream         CUDA stream used for the thrust execution policy
 */
template <typename T, typename nnz_t>
void apply_embedding_updates(T* head_embedding,
                             T* head_buffer,
                             int head_n,
                             T* tail_embedding,
                             T* tail_buffer,
                             int tail_n,
                             UMAPParams* params,
                             bool move_other,
                             rmm::cuda_stream_view stream)
{
  ASSERT(params->deterministic, "Only used when deterministic is set to true.");
  nnz_t n_components = params->n_components;
  // Hoist the element counts out of the device lambda, and iterate with nnz_t
  // instead of uint32_t: a 32-bit counting iterator silently overflows once
  // n * n_components exceeds UINT32_MAX.
  nnz_t head_total = static_cast<nnz_t>(head_n) * n_components;
  if (move_other) {
    nnz_t tail_total = static_cast<nnz_t>(tail_n) * n_components;
    thrust::for_each(rmm::exec_policy(stream),
                     thrust::make_counting_iterator<nnz_t>(0),
                     thrust::make_counting_iterator<nnz_t>(std::max(head_total, tail_total)),
                     [=] __device__(nnz_t i) {
                       if (i < head_total) {
                         head_embedding[i] += head_buffer[i];
                         head_buffer[i] = T(0);
                       }
                       if (i < tail_total) {
                         tail_embedding[i] += tail_buffer[i];
                         tail_buffer[i] = T(0);
                       }
                     });
  } else {
    // The reference (tail) embedding is not written, so only head is applied.
    thrust::for_each(rmm::exec_policy(stream),
                     thrust::make_counting_iterator<nnz_t>(0),
                     thrust::make_counting_iterator<nnz_t>(head_total),
                     [=] __device__(nnz_t i) {
                       head_embedding[i] += head_buffer[i];
                       head_buffer[i] = T(0);
                     });
  }
}

/**
* \brief Constructs a rounding factor used to truncate elements in a sum such that the
* sum of the truncated elements is the same no matter what the order of the sum is.
Expand Down Expand Up @@ -268,13 +226,15 @@ void optimize_layout(T* head_embedding,
has_outlier = check_outliers<nnz_t, TPB_X>(tail, tail_n, nnz, threshold_for_outlier, stream);
}

if (has_outlier) {
if (has_outlier || params->deterministic) {
Copy link
Copy Markdown
Member

@dantegd dantegd Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: If deterministic=true but has_outlier=false then no additional chunking is applied (num_chunks stays at 1), but is there a chance that the outlier detection (check_outliers) may miss edge cases, since it is a heuristic at the end of the day?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is possible, but to prevent this we have to be overly conservative. We could default to a larger num_chunks for when deterministic=true (like 4 maybe?). This has been working well so far with the synthetic/real datasets that I've been working on, but you're right that it's difficult to be 100% confident that this will cover all edge cases.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it might be worth adding a super "strict" mode that always does this, so that if a user can turn it on explicitly, with documentation that it shouldn't be needed in general and just to be used as a "last resource"?

// Shuffling is necessary when outliers may be present (i.e., dense points that undergo many
// updates). It is critical to avoid having too many threads update the same embedding vector
// simultaneously, as this can affect correctness. By shuffling, potential outlier points are
// distributed across threads, rather than being processed by consecutive threads that are
// scheduled together. This approach relies on the GPU's inability to physically schedule all
// nnz edges at once.
// We also shuffle when deterministic behavior is requested, to ensure that updates for the
// same vertex are processed in different kernel launches.
auto first =
thrust::make_zip_iterator(thrust::make_tuple(thrust::device_pointer_cast(head),
thrust::device_pointer_cast(tail),
Expand All @@ -292,6 +252,16 @@ void optimize_layout(T* head_embedding,
}
}

if (has_outlier && params->deterministic) {
// When processing in deterministic mode on datasets that are likely to contain outliers, use
// the heuristic below to choose the number of chunks.
if (nnz > 100000) {
num_chunks = raft::ceildiv(nnz, static_cast<nnz_t>(100000));
} else if (nnz > 10000) {
num_chunks = raft::ceildiv(nnz, static_cast<nnz_t>(10000));
}
Comment thread
jinsolp marked this conversation as resolved.
}

rmm::device_uvector<T> epoch_of_next_negative_sample(nnz, stream);
T nsr_inv = T(1.0) / params->negative_sample_rate;
raft::linalg::unaryOp<T>(
Expand All @@ -307,21 +277,38 @@ void optimize_layout(T* head_embedding,
// Buffers used to store the gradient updates to avoid conflicts
rmm::device_uvector<T> head_buffer(0, stream_view);
rmm::device_uvector<T> tail_buffer(0, stream_view);

// Flags to track which vertices were modified per chunk (for sparse apply)
rmm::device_uvector<uint32_t> head_flags(0, stream_view);
rmm::device_uvector<uint32_t> tail_flags(0, stream_view);
// Write to embedding directly if deterministic is not needed.
T* d_head_buffer = head_embedding;
T* d_tail_buffer = tail_embedding;
T* d_head_buffer = head_embedding;
T* d_tail_buffer = tail_embedding;
uint32_t* d_head_flags = nullptr;
uint32_t* d_tail_flags = nullptr;
if (params->deterministic) {
nnz_t n_components = params->n_components;
head_buffer.resize(head_n * n_components, stream_view);
RAFT_CUDA_TRY(
cudaMemsetAsync(head_buffer.data(), '\0', sizeof(T) * head_buffer.size(), stream));

int head_flag_words = raft::ceildiv(head_n, 32);
head_flags.resize(head_flag_words, stream_view);
RAFT_CUDA_TRY(
cudaMemsetAsync(head_flags.data(), '\0', sizeof(uint32_t) * head_flag_words, stream));
d_head_buffer = head_buffer.data();
d_head_flags = head_flags.data();
// No need for tail if it's not being written.
if (move_other) {
tail_buffer.resize(tail_n * n_components, stream_view);
RAFT_CUDA_TRY(
cudaMemsetAsync(tail_buffer.data(), '\0', sizeof(T) * tail_buffer.size(), stream));
int tail_flag_words = raft::ceildiv(tail_n, 32);
tail_flags.resize(tail_flag_words, stream_view);
RAFT_CUDA_TRY(
cudaMemsetAsync(tail_flags.data(), '\0', sizeof(uint32_t) * tail_flag_words, stream));
d_tail_flags = tail_flags.data();
}
d_head_buffer = head_buffer.data();
Comment thread
dantegd marked this conversation as resolved.
d_tail_buffer = tail_buffer.data();
}

Expand All @@ -335,8 +322,11 @@ void optimize_layout(T* head_embedding,
for (int n = 0; n < n_epochs; n++) {
call_optimize_batch_kernel<T, nnz_t, TPB_X>(head_embedding,
d_head_buffer,
d_head_flags,
head_n,
tail_embedding,
d_tail_buffer,
d_tail_flags,
tail_n,
head,
tail,
Expand All @@ -354,17 +344,6 @@ void optimize_layout(T* head_embedding,
blk,
stream,
rounding);
if (params->deterministic) {
apply_embedding_updates<T, nnz_t>(head_embedding,
d_head_buffer,
head_n,
tail_embedding,
d_tail_buffer,
tail_n,
params,
move_other,
stream_view);
}
RAFT_CUDA_TRY(cudaGetLastError());
optimization_iteration_finalization(params, head_embedding, alpha, n, n_epochs, seed);
}
Expand Down
Loading