@@ -257,14 +257,37 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::AnyCuArray{T},
257257 else
258258 # we need multiple steps to cover all values to reduce
259259 partial = similar (R, (size (R)... , reduce_blocks))
260+
261+ # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
262+ # might not match the original output container (e.g. if that was a view).
263+ # recalculate kernel configuration for the partial array
264+ partial_kernel = @cuda launch= false partial_mapreduce_grid (f, op, init, Rreduce, Rother, Val (shuffle), partial, A)
265+ partial_kernel_config = launch_configuration (partial_kernel. fun; shmem= compute_shmem∘ compute_threads)
266+ partial_reduce_threads = compute_threads (partial_kernel_config. threads)
267+ partial_reduce_shmem = compute_shmem (partial_reduce_threads)
268+
269+ # recalculate blocks based on the new thread count
270+ partial_reduce_blocks = if other_blocks >= partial_kernel_config. blocks
271+ 1
272+ else
273+ min (cld (length (Rreduce), partial_reduce_threads), # how many we need at most
274+ cld (partial_kernel_config. blocks, other_blocks)) # maximize occupancy
275+ end
276+
277+ partial_threads = partial_reduce_threads
278+ partial_shmem = partial_reduce_shmem
279+ partial_blocks = partial_reduce_blocks* other_blocks
280+
281+ if reduce_blocks != partial_blocks
282+ partial = similar (R, (size (R)... , partial_blocks))
283+ end
284+
260285 if init === nothing
261286 # without an explicit initializer we need to copy from the output container
262287 partial .= R
263288 end
264- # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
265- # might not match the original output container (e.g. if that was a view).
266- @cuda (threads, blocks, shmem,
267- partial_mapreduce_grid (f, op, init, Rreduce, Rother, Val (shuffle), partial, A))
289+
290+ partial_kernel (f, op, init, Rreduce, Rother, Val (shuffle), partial, A; threads= partial_threads, blocks= partial_blocks, shmem= partial_shmem)
268291
269292 GPUArrays. mapreducedim! (identity, op, R, partial; init= init)
270293 end
0 commit comments