From 5e66b9577ad722624949ea330f8af6fbebab5061 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 13 Oct 2025 14:52:19 +0000 Subject: [PATCH] speed up scalars --- mlx/backend/cuda/allocator.cpp | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 329906a13f..e1507c631a 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -30,15 +30,20 @@ SmallSizePool::SmallSizePool() { next_free_ = buffer_; CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size)); + + int device_count = 0; + CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count)); + for (int i = 0; i < device_count; ++i) { #if CUDART_VERSION >= 13000 - cudaMemLocation loc; - loc.type = cudaMemLocationTypeDevice; - loc.id = 0; + cudaMemLocation loc; + loc.type = cudaMemLocationTypeDevice; + loc.id = i; #else - int loc = 0; + int loc = i; #endif // CUDART_VERSION >= 13000 - CHECK_CUDA_ERROR( - cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc)); + CHECK_CUDA_ERROR( + cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc)); + } auto curr = next_free_; for (size_t i = 1; i < num_blocks; ++i) {