Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[sources]
SPIRVIntrinsics = {path = "lib/intrinsics"}
KernelAbstractions = {rev = "main", url = "https://github.com/JuliaGPU/KernelAbstractions.jl"}

[compat]
Adapt = "4"
GPUArrays = "11.2.1"
GPUCompiler = "1.7.1"
KernelAbstractions = "0.9.38"
KernelAbstractions = "0.9, 0.10"
LLVM = "9.1"
LinearAlgebra = "1"
OpenCL_jll = "=2024.10.24"
Expand Down
68 changes: 49 additions & 19 deletions src/OpenCLKernels.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module OpenCLKernels

using ..OpenCL
using ..OpenCL: @device_override, method_table
using ..OpenCL: @device_override, method_table, kernel_convert, clfunction

import KernelAbstractions as KA
import KernelAbstractions.KernelIntrinsics as KI

import StaticArrays

Expand Down Expand Up @@ -126,33 +127,62 @@ function (obj::KA.Kernel{OpenCLBackend})(args...; ndrange=nothing, workgroupsize
return nothing
end

# Convert a host-side launch argument for the OpenCL backend by delegating
# to `kernel_convert`.
function KI.argconvert(::OpenCLBackend, arg)
    return kernel_convert(arg)
end

# Build a `KI.Kernel` for the OpenCL backend.
#
# `f` and the argument tuple type `tt` are handed to `clfunction` together with
# `name` and any remaining keyword arguments; the object it returns is wrapped,
# alongside a fresh `OpenCLBackend()`, in a `KI.Kernel` parameterised on the
# wrapped object's concrete type.
function KI.kernel_function(::OpenCLBackend, f::F, tt::TT=Tuple{}; name = nothing, kwargs...) where {F,TT}
    compiled = clfunction(f, tt; name = name, kwargs...)
    backend = OpenCLBackend()
    return KI.Kernel{OpenCLBackend, typeof(compiled)}(backend, compiled)
end

# Launch the wrapped kernel.
#
# `numworkgroups` and `workgroupsize` may each be an integer or a tuple of up to
# three entries. After `KI.check_launch_args` has run, both are padded with
# trailing ones to three entries, and the global size is their element-wise
# product. Always returns `nothing`.
function (obj::KI.Kernel{OpenCLBackend})(args...; numworkgroups = 1, workgroupsize = 1)
    KI.check_launch_args(numworkgroups, workgroupsize)

    # Splatting works for scalars as well as tuples; missing trailing
    # dimensions are filled with 1.
    pad_to_three(dims) = (dims..., ntuple(_ -> 1, 3 - length(dims))...)

    local_size = pad_to_three(workgroupsize)
    groups = pad_to_three(numworkgroups)
    global_size = local_size .* groups

    obj.kern(args...; local_size, global_size)
    return nothing
end


# Return, as an `Int`, the smaller of `max_work_items` and the `size` field of
# `cl.work_group_info` queried for this kernel's `kern.fun` on the device
# returned by `cl.device()`.
function KI.kernel_max_work_group_size(kernel::KI.Kernel{<:OpenCLBackend}; max_work_items::Int=typemax(Int))::Int
    fun = kernel.kern.fun
    device = cl.device()
    limit = cl.work_group_info(fun, device).size
    return Int(min(limit, max_work_items))
end
# Return the `max_work_group_size` property of the device given by
# `cl.device()`, converted to `Int`.
function KI.max_work_group_size(::OpenCLBackend)::Int
    device = cl.device()
    return Int(device.max_work_group_size)
end
# Return the `max_compute_units` property of the device given by
# `cl.device()`, converted to `Int`.
function KI.multiprocessor_count(::OpenCLBackend)::Int
    device = cl.device()
    return Int(device.max_compute_units)
end

## Indexing Functions
## COV_EXCL_START

@device_override @inline function KA.__index_Local_Linear(ctx)
return get_local_id(1)
@device_override @inline function KI.get_local_id()
return (; x = Int(get_local_id(1)), y = Int(get_local_id(2)), z = Int(get_local_id(3)))
end

@device_override @inline function KA.__index_Group_Linear(ctx)
return get_group_id(1)
@device_override @inline function KI.get_group_id()
return (; x = Int(get_group_id(1)), y = Int(get_group_id(2)), z = Int(get_group_id(3)))
end

@device_override @inline function KA.__index_Global_Linear(ctx)
#return get_global_id(1) # JuliaGPU/OpenCL.jl#346
I = KA.__index_Global_Cartesian(ctx)
@inbounds LinearIndices(KA.__ndrange(ctx))[I]
@device_override @inline function KI.get_global_id()
return (; x = Int(get_global_id(1)), y = Int(get_global_id(2)), z = Int(get_global_id(3)))
end

@device_override @inline function KA.__index_Local_Cartesian(ctx)
@inbounds KA.workitems(KA.__iterspace(ctx))[get_local_id(1)]
@device_override @inline function KI.get_local_size()
return (; x = Int(get_local_size(1)), y = Int(get_local_size(2)), z = Int(get_local_size(3)))
end

@device_override @inline function KA.__index_Group_Cartesian(ctx)
@inbounds KA.blocks(KA.__iterspace(ctx))[get_group_id(1)]
@device_override @inline function KI.get_num_groups()
return (; x = Int(get_num_groups(1)), y = Int(get_num_groups(2)), z = Int(get_num_groups(3)))
end

@device_override @inline function KA.__index_Global_Cartesian(ctx)
return @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(1), get_local_id(1))
@device_override @inline function KI.get_global_size()
return (; x = Int(get_global_size(1)), y = Int(get_global_size(2)), z = Int(get_global_size(3)))
end

@device_override @inline function KA.__validindex(ctx)
Expand All @@ -167,7 +197,7 @@ end

## Shared and Scratch Memory

@device_override @inline function KA.SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
@device_override @inline function KI.localmemory(::Type{T}, ::Val{Dims}) where {T, Dims}
ptr = OpenCL.emit_localmemory(T, Val(prod(Dims)))
CLDeviceArray(Dims, ptr)
end
Expand All @@ -179,14 +209,14 @@ end

## Synchronization and Printing

@device_override @inline function KA.__synchronize()
@device_override @inline function KI.barrier()
work_group_barrier(OpenCL.LOCAL_MEM_FENCE | OpenCL.GLOBAL_MEM_FENCE)
end

@device_override @inline function KA.__print(args...)
@device_override @inline function KI._print(args...)
OpenCL._print(args...)
end

## COV_EXCL_STOP

## Other

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# The root Project.toml pins KernelAbstractions to its `main` branch through a
# `[sources]` entry. NOTE(review): presumably that section is not honoured by
# Pkg before Julia 1.11, hence the explicit install of the same revision here —
# confirm, and drop this once a tagged release satisfies `[compat]`.
@static if VERSION < v"1.11"
using Pkg
Pkg.add(url="https://github.com/JuliaGPU/KernelAbstractions.jl", rev="main")
end

using Distributed
using Dates
import REPL
Expand Down
Loading