@@ -232,13 +232,42 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::AnyCuArray{T},
232232 reduce_threads = compute_threads(kernel_config.threads)
233233 reduce_shmem = compute_shmem(reduce_threads)
234234
235+ # how many blocks should we launch?
236+ #
237+ # even though we can always reduce each slice in a single thread block, that may not be
238+ # optimal as it might not saturate the GPU. we already launch some blocks to process
239+ # independent dimensions in parallel; pad that number to ensure full occupancy.
240+ other_blocks = length(Rother)
241+ reduce_blocks = if other_blocks >= kernel_config.blocks
242+ 1
243+ else
244+ min(cld(length(Rreduce), reduce_threads), # how many we need at most
245+ cld(kernel_config.blocks, other_blocks)) # maximize occupancy
246+ end
247+
235248 # determine the launch configuration
236249 threads = reduce_threads
237250 shmem = reduce_shmem
238- blocks = length(Rother)
251+ blocks = reduce_blocks * other_blocks
239252
240253 # perform the actual reduction
241- kernel(f, op, init, Rreduce, Rother, Val(shuffle), R, A; threads, blocks, shmem)
254+ if reduce_blocks == 1
255+ # we can cover the dimensions to reduce using a single block
256+ kernel(f, op, init, Rreduce, Rother, Val(shuffle), R, A; threads, blocks, shmem)
257+ else
258+ # we need multiple steps to cover all values to reduce
259+ partial = similar(R, (size(R)..., reduce_blocks))
260+ if init === nothing
261+ # without an explicit initializer we need to copy from the output container
262+ partial .= R
263+ end
264+ # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
265+ # might not match the original output container (e.g. if that was a view).
266+ @cuda(threads, blocks, shmem,
267+ partial_mapreduce_grid(f, op, init, Rreduce, Rother, Val(shuffle), partial, A))
268+
269+ GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
270+ end
242271
243272 return R
244273end
0 commit comments