diff --git a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh index 98234388e5..33abacf588 100644 --- a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh +++ b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh @@ -108,8 +108,8 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding, T nsr_inv, T rounding) { - nnz_t row = (blockIdx.x * static_cast(TPB_X)) + threadIdx.x; - nnz_t skip_size = blockDim.x * gridDim.x; + size_t row = (static_cast(blockIdx.x) * static_cast(TPB_X)) + threadIdx.x; + size_t skip_size = static_cast(blockDim.x) * gridDim.x; T current_reg[n_components], other_reg[n_components], grads[n_components]; while (row < nnz) { @@ -231,8 +231,8 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, T rounding) { extern __shared__ T embedding_shared_mem_updates[]; - nnz_t row = (blockIdx.x * static_cast(TPB_X)) + threadIdx.x; - nnz_t skip_size = blockDim.x * gridDim.x; + size_t row = (static_cast(blockIdx.x) * static_cast(TPB_X)) + threadIdx.x; + size_t skip_size = static_cast(blockDim.x) * gridDim.x; while (row < nnz) { auto _epoch_of_next_sample = epoch_of_next_sample[row];