Skip to content

Commit 9b14634

Browse files
committed
Recalculate threads and blocks in reduction block optimization
1 parent 7edf346 commit 9b14634

1 file changed

Lines changed: 27 additions & 4 deletions

File tree

src/mapreduce.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,37 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::AnyCuArray{T},
257257
else
258258
# we need multiple steps to cover all values to reduce
259259
partial = similar(R, (size(R)..., reduce_blocks))
260+
261+
# NOTE: we can't use the previously-compiled kernel, since the type of `partial`
262+
# might not match the original output container (e.g. if that was a view).
263+
# recalculate kernel configuration for the partial array
264+
partial_kernel = @cuda launch=false partial_mapreduce_grid(f, op, init, Rreduce, Rother, Val(shuffle), partial, A)
265+
partial_kernel_config = launch_configuration(partial_kernel.fun; shmem=compute_shmemcompute_threads)
266+
partial_reduce_threads = compute_threads(partial_kernel_config.threads)
267+
partial_reduce_shmem = compute_shmem(partial_reduce_threads)
268+
269+
# recalculate blocks based on the new thread count
270+
partial_reduce_blocks = if other_blocks >= partial_kernel_config.blocks
271+
1
272+
else
273+
min(cld(length(Rreduce), partial_reduce_threads), # how many we need at most
274+
cld(partial_kernel_config.blocks, other_blocks)) # maximize occupancy
275+
end
276+
277+
partial_threads = partial_reduce_threads
278+
partial_shmem = partial_reduce_shmem
279+
partial_blocks = partial_reduce_blocks*other_blocks
280+
281+
if reduce_blocks != partial_blocks
282+
partial = similar(R, (size(R)..., partial_blocks))
283+
end
284+
260285
if init === nothing
261286
# without an explicit initializer we need to copy from the output container
262287
partial .= R
263288
end
264-
# NOTE: we can't use the previously-compiled kernel, since the type of `partial`
265-
# might not match the original output container (e.g. if that was a view).
266-
@cuda(threads, blocks, shmem,
267-
partial_mapreduce_grid(f, op, init, Rreduce, Rother, Val(shuffle), partial, A))
289+
290+
partial_kernel(f, op, init, Rreduce, Rother, Val(shuffle), partial, A; threads=partial_threads, blocks=partial_blocks, shmem=partial_shmem)
268291

269292
GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
270293
end

0 commit comments

Comments
 (0)