@@ -156,7 +156,8 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
     for (int d = 0; d < n_components; d++) {
       auto diff   = current_reg[d] - other_reg[d];
       auto grad_d = clip<T>(attractive_grad_coeff * diff, T(-4.0), T(4.0));
-      grads[d]    = grad_d * alpha;
+      current_reg[d] += grad_d * alpha;
+      grads[d] = grad_d * alpha;
     }
     // storing gradients for negative samples back to global memory
     if (move_other) {
@@ -200,6 +201,7 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
         grad_d = clip<T>(repulsive_grad_coeff * diff, T(-4.0), T(4.0));
       else
         grad_d = T(4.0);
+      current_reg[d] += grad_d * alpha;
       grads[d] += grad_d * alpha;
     }
   }
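
In the `_reg` variant above, the key change is the `current_reg[d] += grad_d * alpha` line added in both the attractive and repulsive loops: the thread folds each clipped, scaled gradient into its register copy of the current point, so every subsequent distance computation in the same iteration sees the already-moved position rather than a stale one, while `grads[d]` keeps accumulating the total displacement for the single write back to global memory. A minimal sketch of that pattern, assuming a thread-private helper and a hypothetical `MAX_DIM` bound on `n_components` (neither is from the actual kernel):

```cuda
// Sketch: accumulate updates into a register copy and into a separate
// displacement buffer; only the latter is written to global memory at the end.
// MAX_DIM and this helper's shape are assumptions for illustration.
template <typename T, int MAX_DIM>
__device__ void apply_update(T (&current_reg)[MAX_DIM],
                             T (&grads)[MAX_DIM],
                             T grad_coeff,
                             T const* other_reg,
                             T alpha,
                             int n_components)
{
  for (int d = 0; d < n_components; d++) {
    T diff   = current_reg[d] - other_reg[d];
    T grad_d = grad_coeff * diff;
    // equivalent of the kernel's clip<T>(..., T(-4.0), T(4.0))
    grad_d = grad_d < T(-4.0) ? T(-4.0) : (grad_d > T(4.0) ? T(4.0) : grad_d);
    current_reg[d] += grad_d * alpha;  // later distance computations use the updated point
    grads[d] += grad_d * alpha;        // total displacement for the final global write
  }
}
```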
@@ -252,8 +254,17 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
     T* cur_write = head_buffer + (j * n_components);
     T* oth_write = tail_buffer + (k * n_components);
 
+    // To reduce global-memory traffic: load values from global memory once, then
+    // accumulate grads onto this shared-memory position instead of re-reading global memory.
     T* current_buffer{nullptr};
-    if (use_shared_mem) { current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x; }
+    // Keeps track of the grads for the final write back to global memory.
+    T* grads_buffer{nullptr};
+    if constexpr (use_shared_mem) {
+      // n_components for thread 0, then the next n_components for thread 1, ...
+      current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x * n_components;
+      // TPB_X entries for the first component, then another TPB_X for the next component, for better coalescing...
+      grads_buffer = (T*)embedding_shared_mem_updates + TPB_X * n_components + threadIdx.x;
+    }
     auto dist_squared = rdist<T>(current, other, n_components);
     // Attractive force between the two vertices, since they
     // are connected by an edge in the 1-skeleton.
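
The two pointers carve one dynamic shared-memory allocation into two regions with deliberately different layouts: `current_buffer` stores each thread's copy of its current point contiguously (thread-major), which suits the per-thread reads in `rdist`, while `grads_buffer` is component-major, so that when all `TPB_X` threads of a block touch component `d` in lockstep they hit consecutive shared-memory words. A sketch of the resulting index arithmetic (the layout follows the diff; the `extern __shared__` declaration is an assumption):

```cuda
// Assumed declaration of the dynamic shared-memory block used by the kernel:
extern __shared__ char embedding_shared_mem_updates[];

// Region 1 (thread-major), TPB_X * n_components elements:
//   [t0:d0, t0:d1, ..., t0:d(C-1), t1:d0, ...]
// This thread's updated position, component d, is current_buffer[d].
T* current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x * n_components;

// Region 2 (component-major), another TPB_X * n_components elements:
//   [d0:t0, d0:t1, ..., d0:t(TPB_X-1), d1:t0, ...]
// This thread's accumulated gradient, component d, is grads_buffer[d * TPB_X],
// so a block-wide loop over d reads/writes TPB_X consecutive words per step.
T* grads_buffer = (T*)embedding_shared_mem_updates + TPB_X * n_components + threadIdx.x;
```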
@@ -267,10 +278,13 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
      * performing unsupervised training).
      */
     for (int d = 0; d < n_components; d++) {
-      auto grad_d = clip<T>(attractive_grad_coeff * (current[d] - other[d]), T(-4.0), T(4.0));
+      T current_val = current[d];
+      if constexpr (use_shared_mem) { current_buffer[d] = current_val; }
+      auto grad_d = clip<T>(attractive_grad_coeff * (current_val - other[d]), T(-4.0), T(4.0));
       grad_d *= alpha;
-      if (use_shared_mem) {
-        current_buffer[d * TPB_X] = grad_d;
+      if constexpr (use_shared_mem) {
+        current_buffer[d] += grad_d;
+        grads_buffer[d * TPB_X] = grad_d;
       } else {
         raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grad_d));
         if (move_other) {  // happens only during unsupervised training
@@ -282,7 +296,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
     if (use_shared_mem && move_other) {
       __syncthreads();
       for (int d = 0; d < n_components; d++) {
-        auto grad = current_buffer[d * TPB_X];
+        auto grad = grads_buffer[d * TPB_X];
         raft::myAtomicAdd<T>((T*)oth_write + d, truncate_gradient(rounding, -grad));
       }
     }
@@ -299,7 +313,11 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
       gen.next(r);
       nnz_t t                  = r % tail_n;
       T const* negative_sample = tail_embedding + (t * n_components);
-      dist_squared             = rdist<T>(current, negative_sample, n_components);
+      if constexpr (use_shared_mem) {
+        dist_squared = rdist<T>(current_buffer, negative_sample, n_components);
+      } else {
+        dist_squared = rdist<T>(current, negative_sample, n_components);
+      }
       // repulsive force between two vertices
       auto repulsive_grad_coeff = T(0.0);
       if (dist_squared > T(0.0)) {
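
When the shared-memory path is active, the negative-sample distance is now measured from `current_buffer`, i.e. from the position that already absorbed this iteration's attractive update, matching the `_reg` variant's behavior. For reference, `rdist` is the kernel's squared-distance helper; a sketch of its presumed contract (the actual cuML implementation may differ):

```cuda
// Presumed contract of rdist as used above: squared Euclidean distance
// between two n_components-dimensional points (no square root).
template <typename T>
__device__ T rdist(T const* a, T const* b, int n_components)
{
  T result = T(0.0);
  for (int d = 0; d < n_components; d++) {
    T diff = a[d] - b[d];
    result += diff * diff;
  }
  return result;
}
```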
@@ -313,25 +331,31 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
       */
       for (int d = 0; d < n_components; d++) {
         auto grad_d = T(0.0);
-        if (repulsive_grad_coeff > T(0.0))
-          grad_d = clip<T>(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0));
-        else
+        if (repulsive_grad_coeff > T(0.0)) {
+          if constexpr (use_shared_mem) {
+            grad_d = clip<T>(
+              repulsive_grad_coeff * (current_buffer[d] - negative_sample[d]), T(-4.0), T(4.0));
+          } else {
+            grad_d =
+              clip<T>(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0));
+          }
+        } else
           grad_d = T(4.0);
         grad_d *= alpha;
-        if (use_shared_mem) {
-          current_buffer[d * TPB_X] += grad_d;
+        if constexpr (use_shared_mem) {
+          current_buffer[d] += grad_d;
+          grads_buffer[d * TPB_X] += grad_d;
         } else {
           raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grad_d));
         }
       }
     }
 
     // storing gradients for positive samples back to global memory
-    if (use_shared_mem) {
+    if constexpr (use_shared_mem) {
       __syncthreads();
       for (int d = 0; d < n_components; d++) {
-        raft::myAtomicAdd<T>((T*)cur_write + d,
-                             truncate_gradient(rounding, current_buffer[d * TPB_X]));
+        raft::myAtomicAdd<T>((T*)cur_write + d, truncate_gradient(rounding, grads_buffer[d * TPB_X]));
       }
     }
     epoch_of_next_negative_sample[row] =
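
Replacing the runtime `if (use_shared_mem)` checks with `if constexpr` implies `use_shared_mem` is a compile-time template parameter of the kernel, so each instantiation compiles only one of the two paths and the shared-memory variant pays nothing for the atomic fallback. A sketch of that dispatch pattern under those assumptions (template and launch shapes are illustrative, not the actual cuML signature):

```cuda
// Illustrative compile-time dispatch; the parameter list is an assumption.
template <typename T, int TPB_X, bool use_shared_mem>
CUML_KERNEL void optimize_batch_kernel(T* head_embedding /* , ... */)
{
  if constexpr (use_shared_mem) {
    // shared-memory accumulation path: the only code compiled when true
  } else {
    // raft::myAtomicAdd fallback path: the only code compiled when false
  }
}

// Host side would pick the instantiation from the runtime check, e.g.:
// if (use_shared_mem)
//   optimize_batch_kernel<T, TPB_X, true><<<grid, TPB_X, requiredSize, stream>>>(...);
// else
//   optimize_batch_kernel<T, TPB_X, false><<<grid, TPB_X, 0, stream>>>(...);
```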
@@ -373,7 +397,7 @@ void call_optimize_batch_kernel(T const* head_embedding,
                                 cudaStream_t& stream,
                                 T rounding)
 {
-  std::size_t requiredSize = TPB_X * params->n_components;
+  std::size_t requiredSize = TPB_X * params->n_components * 2;
   requiredSize *= sizeof(T);
   bool use_shared_mem = requiredSize < static_cast<std::size_t>(raft::getSharedMemPerBlock());
   T nsr_inv = T(1.0) / params->negative_sample_rate;
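
The sizing change follows directly from the new layout: each block now needs `TPB_X * n_components` elements for the position copies plus the same again for the gradient accumulators, hence the `* 2`. A worked example with illustrative numbers:

```cuda
// Illustrative sizing, assuming float embeddings:
//   TPB_X = 256 threads, n_components = 2
//   requiredSize = 256 * 2 * 2 = 1024 elements = 4096 bytes per block
std::size_t requiredSize = TPB_X * params->n_components * 2;
requiredSize *= sizeof(T);
// Fall back to the atomic-add path when the two regions would not fit in a
// block's shared memory (e.g. very high n_components).
bool use_shared_mem = requiredSize < static_cast<std::size_t>(raft::getSharedMemPerBlock());
```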