Fix UMAP outlier issue by checking for outliers and shuffling #7131
The diff below reworks the UMAP batch-optimization kernels (`optimize_batch_kernel_reg` and `optimize_batch_kernel`): each thread now applies gradient updates to a local copy of the current point as it goes, while a separate buffer tracks the accumulated gradients for the final write back to global memory.
```diff
@@ -156,7 +156,8 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
   for (int d = 0; d < n_components; d++) {
     auto diff = current_reg[d] - other_reg[d];
     auto grad_d = clip<T>(attractive_grad_coeff * diff, T(-4.0), T(4.0));
-    grads[d] = grad_d * alpha;
+    current_reg[d] += grad_d * alpha;
+    grads[d] = grad_d * alpha;
   }
   // storing gradients for negative samples back to global memory
   if (move_other) {
```
```diff
@@ -200,6 +201,7 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
       grad_d = clip<T>(repulsive_grad_coeff * diff, T(-4.0), T(4.0));
     else
       grad_d = T(4.0);
+    current_reg[d] += grad_d * alpha;
     grads[d] += grad_d * alpha;
   }
 }
```
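In both hunks above, the register-path kernel now folds the scaled gradient into `current_reg` immediately, so subsequent distance computations see the point's already-moved position rather than its stale coordinates. For reference, a minimal sketch of the clamping these kernels rely on (the real `clip<T>` lives in cuML's UMAP headers; this only illustrates the semantics assumed here):

```cuda
// Illustrative stand-in for cuML's clip<T>, not the library's definition:
// clamp a gradient component into [lo, hi] so a single update stays bounded.
template <typename T>
__device__ T clip(T val, T lo, T hi)
{
  return val < lo ? lo : (val > hi ? hi : val);
}
```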
```diff
@@ -252,8 +254,17 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
   T* cur_write = head_buffer + (j * n_components);
   T* oth_write = tail_buffer + (k * n_components);

   // for reducing access to global memory. load values from global memory, and accumulate grads onto
   // this shared memory position instead of reading from global memory every time.
   T* current_buffer{nullptr};
-  if (use_shared_mem) { current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x; }
+  // for keeping track of grads, final write to global memory
+  T* grads_buffer{nullptr};
+  if constexpr (use_shared_mem) {
+    // n_components for thread0, then the next n_components for thread1 ...
+    current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x * n_components;
+    // TPB_X for first component, then another TPB_X for the next component for better coalescing...
+    grads_buffer = (T*)embedding_shared_mem_updates + TPB_X * n_components + threadIdx.x;
+  }
   auto dist_squared = rdist<T>(current, other, n_components);
   // Attractive force between the two vertices, since they
   // are connected by an edge in the 1-skeleton.
```
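Two details worth noting in this hunk: the switch from `if` to `if constexpr` implies `use_shared_mem` is now a compile-time (template) constant, so the unused branch is discarded entirely; and the dynamic shared-memory block is split into two differently laid-out regions. A hypothetical standalone sketch of that layout (helper names are mine, not cuML's):

```cuda
// Layout sketch, under the assumption of TPB_X threads per block and
// n_components values per thread. One shared-memory block holds two
// regions: TPB_X * n_components positions, then the same again for grads.
template <typename T>
__device__ T* current_slot(T* smem, int n_components)
{
  // thread-major: thread t owns the contiguous slice
  // smem[t * n_components .. t * n_components + n_components - 1]
  return smem + threadIdx.x * n_components;  // index with [d]
}

template <typename T, int TPB_X>
__device__ T* grads_slot(T* smem, int n_components)
{
  // component-major: component d of thread t sits at
  // smem[TPB_X * n_components + d * TPB_X + t], so adjacent threads touch
  // adjacent words for a given d
  return smem + TPB_X * n_components + threadIdx.x;  // index with [d * TPB_X]
}
```

Positions are read back repeatedly by their owning thread, so a contiguous per-thread slice is natural there; gradients are flushed once per component by all threads together, and per the in-code comments the `TPB_X` stride is chosen so those final writes behave better (adjacent threads hitting adjacent addresses).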
```diff
@@ -267,10 +278,13 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
    * performing unsupervised training).
    */
   for (int d = 0; d < n_components; d++) {
-    auto grad_d = clip<T>(attractive_grad_coeff * (current[d] - other[d]), T(-4.0), T(4.0));
+    T current_val = current[d];
+    if constexpr (use_shared_mem) { current_buffer[d] = current_val; }
+    auto grad_d = clip<T>(attractive_grad_coeff * (current_val - other[d]), T(-4.0), T(4.0));
     grad_d *= alpha;
-    if (use_shared_mem) {
-      current_buffer[d * TPB_X] = grad_d;
+    if constexpr (use_shared_mem) {
+      current_buffer[d] += grad_d;
+      grads_buffer[d * TPB_X] = grad_d;
     } else {
       raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grad_d));
       if (move_other) { // happens only during unsupervised training
```
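The attractive pass now snapshots `current[d]` once, seeds the shared-memory position buffer with it, and accumulates the move into `current_buffer[d]` while the raw gradient goes to `grads_buffer`. For background, a hedged sketch of the conventional UMAP attractive-gradient coefficient (as in the UMAP paper and umap-learn; cuML computes its coefficient upstream of this hunk, so this is context, not the PR's code):

```cuda
// Conventional UMAP attractive gradient coefficient (background sketch):
//   -2ab * d^(2(b-1)) / (a * d^(2b) + 1), where d^2 is the squared distance
// and a, b are the curve parameters fit from min_dist/spread.
template <typename T>
__device__ T attractive_grad_coeff(T a, T b, T dist_squared)
{
  return (T(-2.0) * a * b * pow(dist_squared, b - T(1.0))) /
         (a * pow(dist_squared, b) + T(1.0));
}
```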
```diff
@@ -282,7 +296,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
   if (use_shared_mem && move_other) {
     __syncthreads();
     for (int d = 0; d < n_components; d++) {
-      auto grad = current_buffer[d * TPB_X];
+      auto grad = grads_buffer[d * TPB_X];
       raft::myAtomicAdd<T>((T*)oth_write + d, truncate_gradient(rounding, -grad));
     }
   }
```
```diff
@@ -299,7 +313,11 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
     gen.next(r);
     nnz_t t = r % tail_n;
     T const* negative_sample = tail_embedding + (t * n_components);
-    dist_squared = rdist<T>(current, negative_sample, n_components);
+    if constexpr (use_shared_mem) {
+      dist_squared = rdist<T>(current_buffer, negative_sample, n_components);
+    } else {
+      dist_squared = rdist<T>(current, negative_sample, n_components);
+    }
     // repulsive force between two vertices
     auto repulsive_grad_coeff = T(0.0);
     if (dist_squared > T(0.0)) {
```
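With the position now tracked in shared memory, negative-sample distances are measured against the updated point (`current_buffer`) instead of the original global-memory coordinates. A minimal sketch of what `rdist<T>` is assumed to compute (the real helper lives in cuML; the analogous `rdist` in umap-learn is the squared Euclidean distance):

```cuda
// Background sketch, not cuML's definition: squared Euclidean ("reduced")
// distance between two n_components-dimensional embedding points.
template <typename T>
__device__ T rdist(T const* x, T const* y, int n_components)
{
  T sum = T(0.0);
  for (int d = 0; d < n_components; d++) {
    T diff = x[d] - y[d];
    sum += diff * diff;
  }
  return sum;
}
```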
```diff
@@ -313,25 +331,31 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
      */
     for (int d = 0; d < n_components; d++) {
       auto grad_d = T(0.0);
-      if (repulsive_grad_coeff > T(0.0))
-        grad_d = clip<T>(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0));
-      else
+      if (repulsive_grad_coeff > T(0.0)) {
+        if constexpr (use_shared_mem) {
+          grad_d = clip<T>(
+            repulsive_grad_coeff * (current_buffer[d] - negative_sample[d]), T(-4.0), T(4.0));
+        } else {
+          grad_d =
+            clip<T>(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0));
+        }
+      } else
         grad_d = T(4.0);
       grad_d *= alpha;
-      if (use_shared_mem) {
-        current_buffer[d * TPB_X] += grad_d;
+      if constexpr (use_shared_mem) {
+        current_buffer[d] += grad_d;
+        grads_buffer[d * TPB_X] += grad_d;
       } else {
         raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grad_d));
       }
     }
   }

   // storing gradients for positive samples back to global memory
-  if (use_shared_mem) {
+  if constexpr (use_shared_mem) {
     __syncthreads();
     for (int d = 0; d < n_components; d++) {
-      raft::myAtomicAdd<T>((T*)cur_write + d,
-                           truncate_gradient(rounding, current_buffer[d * TPB_X]));
+      raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grads_buffer[d * TPB_X]));
     }
   }
   epoch_of_next_negative_sample[row] =
```
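The repulsive pass mirrors the attractive one: gradients are computed against the locally updated position when shared memory is in use, accumulated into both buffers, and only the `grads_buffer` contents are atomically flushed to global memory at the end. For background, the conventional UMAP repulsive-gradient coefficient (again a hedged sketch in the style of umap-learn, not the PR's code):

```cuda
// Conventional UMAP repulsive gradient coefficient (background sketch):
//   2 * gamma * b / ((0.001 + d^2) * (a * d^(2b) + 1)).
// gamma weights repulsion; the 0.001 guards against division by zero.
template <typename T>
__device__ T repulsive_grad_coeff(T a, T b, T gamma, T dist_squared)
{
  return (T(2.0) * gamma * b) /
         ((T(0.001) + dist_squared) * (a * pow(dist_squared, b) + T(1.0)));
}
```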
```diff
@@ -373,7 +397,7 @@ void call_optimize_batch_kernel(T const* head_embedding,
                                 cudaStream_t& stream,
                                 T rounding)
 {
-  std::size_t requiredSize = TPB_X * params->n_components;
+  std::size_t requiredSize = TPB_X * params->n_components * 2;
   requiredSize *= sizeof(T);
   bool use_shared_mem = requiredSize < static_cast<std::size_t>(raft::getSharedMemPerBlock());
   T nsr_inv = T(1.0) / params->negative_sample_rate;
```
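The shared-memory request doubles because the block now carries both the per-thread positions and the per-thread gradients. A quick back-of-the-envelope check with illustrative values (the `TPB_X` and `n_components` numbers here are assumptions for the example, not cuML defaults):

```cuda
#include <cstddef>
#include <cstdio>

int main()
{
  std::size_t TPB_X        = 256;  // example threads per block
  std::size_t n_components = 2;    // typical 2-D UMAP embedding
  // positions region + gradients region, each TPB_X * n_components values
  std::size_t requiredSize = TPB_X * n_components * 2 * sizeof(float);
  std::printf("%zu bytes of dynamic shared memory\n", requiredSize);  // 4096
  return 0;
}
```

Even doubled, a few KB is comfortably below what `raft::getSharedMemPerBlock()` reports on modern GPUs (commonly 48 KB or more), so the `use_shared_mem` fast path should remain enabled for typical embedding dimensions.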