23 changes: 17 additions & 6 deletions lib/CUDAKernels/src/CUDAKernels.jl
@@ -95,7 +95,8 @@ end

import KernelAbstractions: Event, CPUEvent, NoneEvent, MultiEvent, CPU, GPU, isdone, failed

struct CUDADevice <: GPU end
struct CUDADevice{PreferBlocks} <: GPU end
CUDADevice() = CUDADevice{false}()

struct CudaEvent <: Event
event::CUDA.CuEvent
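
Note: `CUDADevice()` still constructs the default `CUDADevice{false}()`, so existing user code keeps its behavior; the blocks-first launch heuristic is opt-in via the type parameter. A minimal usage sketch under the usual KernelAbstractions launch pattern (the `mul2!` kernel is illustrative, not part of this PR):

    using KernelAbstractions, CUDAKernels, CUDA

    @kernel function mul2!(a)
        I = @index(Global)
        @inbounds a[I] *= 2
    end

    a = CUDA.ones(Float32, 1024)
    # Default device: threads-first heuristic, identical to the old behavior.
    wait(mul2!(CUDADevice())(a, ndrange=length(a)))
    # Opt in to the new blocks-first heuristic.
    wait(mul2!(CUDADevice{true}())(a, ndrange=length(a)))
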
@@ -232,7 +233,7 @@ import KernelAbstractions: Kernel, StaticSize, DynamicSize, partition, blocks, w
###
# Kernel launch
###
function launch_config(kernel::Kernel{CUDADevice}, ndrange, workgroupsize)
function launch_config(kernel::Kernel{<:CUDADevice}, ndrange, workgroupsize)
if ndrange isa Integer
ndrange = (ndrange,)
end
@@ -265,7 +266,7 @@ function threads_to_workgroupsize(threads, ndrange)
end
end

function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(CUDADevice()), workgroupsize=nothing, progress=yield)
function (obj::Kernel{CUDADevice{PreferBlocks}})(args...; ndrange=nothing, dependencies=Event(CUDADevice()), workgroupsize=nothing, progress=yield) where PreferBlocks

ndrange, workgroupsize, iterspace, dynamic = launch_config(obj, ndrange, workgroupsize)
# this might not be the final context, since we may tune the workgroupsize
@@ -275,7 +276,17 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(
# figure out the optimal workgroupsize automatically
if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))
workgroupsize = threads_to_workgroupsize(config.threads, ndrange)
if PreferBlocks
# Prefer blocks over threads
threads = min(prod(ndrange), config.threads)
            # XXX: Some kernels perform much better with all blocks active
cu_blocks = max(cld(prod(ndrange), threads), config.blocks)
threads = cld(prod(ndrange), cu_blocks)
else
threads = config.threads
end

workgroupsize = threads_to_workgroupsize(threads, ndrange)
iterspace, dynamic = partition(obj, ndrange, workgroupsize)
ctx = mkcontext(obj, ndrange, iterspace)
end
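
Note: the `PreferBlocks` branch trades threads per workgroup for active blocks. A worked example with illustrative numbers (not taken from this PR): suppose `prod(ndrange) == 1_000_000` and the occupancy API returns `config.threads == 1024` and `config.blocks == 2048`. Then:

    threads   = min(1_000_000, 1024)             # 1024
    cu_blocks = max(cld(1_000_000, 1024), 2048)  # max(977, 2048) = 2048
    threads   = cld(1_000_000, 2048)             # 489

The workgroup shrinks from 1024 to 489 threads so that all 2048 suggested blocks stay active, whereas the previous code would always have used the full 1024 threads per workgroup.
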
@@ -311,7 +322,7 @@ import KernelAbstractions: CompilerMetadata, DynamicCheck, LinearIndices
import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds

function mkcontext(kernel::Kernel{CUDADevice}, _ndrange, iterspace)
function mkcontext(kernel::Kernel{<:CUDADevice}, _ndrange, iterspace)
CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
end

@@ -393,6 +404,6 @@ end
Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental.Const(a)

# Argument conversion
KernelAbstractions.argconvert(k::Kernel{CUDADevice}, arg) = CUDA.cudaconvert(arg)
KernelAbstractions.argconvert(k::Kernel{<:CUDADevice}, arg) = CUDA.cudaconvert(arg)

end
3 changes: 2 additions & 1 deletion lib/CUDAKernels/test/runtests.jl
@@ -27,7 +27,8 @@ end
if CUDA.functional()
CUDA.versioninfo()
CUDA.allowscalar(false)
Testsuite.testsuite(CUDADevice, backend, CUDA, CuArray, CUDA.CuDeviceArray)
Testsuite.testsuite(CUDADevice{false}, backend, CUDA, CuArray, CUDA.CuDeviceArray)
Testsuite.testsuite(CUDADevice{true}, backend, CUDA, CuArray, CUDA.CuDeviceArray)
# GradientsTestsuite.testsuite(CUDADevice, backend, CUDA, CuArray, CUDA.CuDeviceArray)
elseif !CI
error("No CUDA GPUs available!")
15 changes: 10 additions & 5 deletions test/test.jl
@@ -70,15 +70,20 @@ end
x = ArrayT(rand(Float32, 5))
A = ArrayT(rand(Float32, 5,5))
device = backend()
@test @inferred(KernelAbstractions.get_device(A)) == device
@test @inferred(KernelAbstractions.get_device(view(A, 2:4, 1:3))) == device
if isdefined(Main, :CUDAKernels) && (device isa Main.CUDAKernels.CUDADevice)
    deviceT = Main.CUDAKernels.CUDADevice
else
    deviceT = typeof(device)
end
@test @inferred(KernelAbstractions.get_device(A)) isa deviceT
@test @inferred(KernelAbstractions.get_device(view(A, 2:4, 1:3))) isa deviceT
if !(isdefined(Main, :ROCKernels) && (device isa Main.ROCKernels.ROCDevice)) &&
!(isdefined(Main, :oneAPIKernels) && (device isa Main.oneAPIKernels.oneAPIDevice))
# Sparse arrays are not supported by the ROCm or oneAPI backends yet:
@test @inferred(KernelAbstractions.get_device(sparse(A))) == device
@test @inferred(KernelAbstractions.get_device(sparse(A))) isa deviceT
end
@test @inferred(KernelAbstractions.get_device(Diagonal(x))) == device
@test @inferred(KernelAbstractions.get_device(Tridiagonal(A))) == device
@test @inferred(KernelAbstractions.get_device(Diagonal(x))) isa deviceT
@test @inferred(KernelAbstractions.get_device(Tridiagonal(A))) isa deviceT
end

@testset "indextest" begin
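
Note: the comparisons above changed from `== device` to `isa deviceT` because `CUDADevice` is now parametric and the testsuite runs with both `CUDADevice{false}` and `CUDADevice{true}`: two instances with different `PreferBlocks` parameters are not `==` (they are different concrete singleton types), but both match the `CUDADevice` UnionAll. A quick illustration:

    CUDADevice{true}() == CUDADevice{false}()    # false: different concrete types
    CUDADevice{true}() isa CUDADevice            # true
    CUDADevice{false}() isa CUDADevice           # true
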