Skip to content

Commit aca6d1c

Browse files
committed
simplify implementation
partially by making use of JuliaGPU/GPUCompiler.jl#727
1 parent a767801 commit aca6d1c

File tree

6 files changed

+72
-76
lines changed

6 files changed

+72
-76
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ jobs:
136136
julia --project -e '
137137
using Pkg
138138
Pkg.develop(path="lib/intrinsics")
139-
Pkg.add(name="GPUCompiler", rev="tb/kernel_state_reference")'
139+
Pkg.add(name="GPUCompiler", rev="tb/linked_module")'
140140
141141
- name: Test OpenCL.jl
142142
uses: julia-actions/julia-runtest@v1

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
2222
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2323

2424
[sources]
25-
GPUCompiler = {rev = "tb/kernel_state_reference", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}
25+
GPUCompiler = {rev = "tb/linked_module", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}
2626
SPIRVIntrinsics = {path = "lib/intrinsics"}
2727

2828
[compat]

src/compiler/compilation.jl

Lines changed: 62 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1818
in(fn, known_intrinsics) ||
1919
contains(fn, "__spirv_")
2020

21-
GPUCompiler.kernel_state_type(job::OpenCLCompilerJob) = KernelState
21+
GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState
2222

2323
function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
2424
mod::LLVM.Module, entry::LLVM.Function)
@@ -28,71 +28,72 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
2828

2929
# if this kernel uses our RNG, we should prime the shared state.
3030
# XXX: these transformations should really happen at the Julia IR level...
31-
if haskey(functions(mod), "julia.spirv.random_keys")
31+
if haskey(functions(mod), "julia.spirv.random_keys") && job.config.kernel
32+
# insert call to `initialize_rng_state`
33+
f = initialize_rng_state
34+
ft = typeof(f)
35+
tt = Tuple{}
36+
37+
# create a deferred compilation job for `initialize_rng_state`
38+
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
39+
cfg = CompilerConfig(job.config; kernel=false, name=nothing)
40+
job = CompilerJob(src, cfg, job.world)
41+
id = length(GPUCompiler.deferred_codegen_jobs) + 1
42+
GPUCompiler.deferred_codegen_jobs[id] = job
43+
44+
# generate IR for calls to `deferred_codegen` and the resulting function pointer
45+
top_bb = first(blocks(entry))
46+
bb = BasicBlock(top_bb, "initialize_rng")
47+
@dispose builder=IRBuilder() begin
48+
position!(builder, bb)
49+
subprogram = LLVM.subprogram(entry)
50+
if subprogram !== nothing
51+
loc = DILocation(0, 0, subprogram)
52+
debuglocation!(builder, loc)
53+
end
54+
debuglocation!(builder, first(instructions(top_bb)))
55+
56+
# call the `deferred_codegen` marker function
57+
T_ptr = if LLVM.version() >= v"17"
58+
LLVM.PointerType()
59+
elseif VERSION >= v"1.12.0-DEV.225"
60+
LLVM.PointerType(LLVM.Int8Type())
61+
else
62+
LLVM.Int64Type()
63+
end
64+
T_id = convert(LLVMType, Int)
65+
deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
66+
deferred_codegen = if haskey(functions(mod), "deferred_codegen")
67+
functions(mod)["deferred_codegen"]
68+
else
69+
LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
70+
end
71+
fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])
72+
73+
# call the `initialize_rng_state` function
74+
rt = Core.Compiler.return_type(f, tt)
75+
llvm_rt = convert(LLVMType, rt)
76+
llvm_ft = LLVM.FunctionType(llvm_rt)
77+
fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
78+
call!(builder, llvm_ft, fptr)
79+
br!(builder, top_bb)
80+
end
81+
82+
# XXX: put some of the above behind GPUCompiler abstractions
83+
# (e.g., a compile-time version of `deferred_codegen`)
84+
end
85+
return entry
86+
end
87+
88+
function GPUCompiler.finish_linked_module!(@nospecialize(job::OpenCLCompilerJob), mod::LLVM.Module)
89+
for f in GPUCompiler.kernels(mod)
3290
kernel_intrinsics = Dict(
3391
"julia.spirv.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
3492
"julia.spirv.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
3593
)
36-
entry = GPUCompiler.add_input_arguments!(job, mod, entry, kernel_intrinsics)
37-
38-
if job.config.kernel
39-
# insert call to `initialize_rng_state`
40-
f = initialize_rng_state
41-
ft = typeof(f)
42-
tt = NTuple{2, LLVMPtr{UInt32, AS.Workgroup}}
43-
44-
# create a deferred compilation job for `initialize_rng_state`
45-
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
46-
cfg = CompilerConfig(job.config; kernel=false, name=nothing)
47-
job = CompilerJob(src, cfg, job.world)
48-
id = length(GPUCompiler.deferred_codegen_jobs) + 1
49-
GPUCompiler.deferred_codegen_jobs[id] = job
50-
51-
# generate IR for calls to `deferred_codegen` and the resulting function pointer
52-
top_bb = first(blocks(entry))
53-
bb = BasicBlock(top_bb, "initialize_rng")
54-
@dispose builder=IRBuilder() begin
55-
position!(builder, bb)
56-
subprogram = LLVM.subprogram(entry)
57-
if subprogram !== nothing
58-
loc = DILocation(0, 0, subprogram)
59-
debuglocation!(builder, loc)
60-
end
61-
debuglocation!(builder, first(instructions(top_bb)))
62-
63-
# call the `deferred_codegen` marker function
64-
T_ptr = if LLVM.version() >= v"17"
65-
LLVM.PointerType()
66-
elseif VERSION >= v"1.12.0-DEV.225"
67-
LLVM.PointerType(LLVM.Int8Type())
68-
else
69-
LLVM.Int64Type()
70-
end
71-
T_id = convert(LLVMType, Int)
72-
deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
73-
deferred_codegen = if haskey(functions(mod), "deferred_codegen")
74-
functions(mod)["deferred_codegen"]
75-
else
76-
LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
77-
end
78-
fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])
79-
80-
# call the `initialize_rng_state` function
81-
rt = Core.Compiler.return_type(f, tt)
82-
llvm_rt = convert(LLVMType, rt)
83-
llvm_ft = LLVM.FunctionType(llvm_rt, [convert(LLVMType, LLVMPtr{UInt32, AS.Workgroup}) for _ in 1:2])
84-
fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
85-
random_keys = findfirst(arg -> name(arg) == "random_keys", parameters(entry))
86-
random_counters = findfirst(arg -> name(arg) == "random_counters", parameters(entry))
87-
call!(builder, llvm_ft, fptr, parameters(entry)[[random_keys, random_counters]])
88-
br!(builder, top_bb)
89-
end
90-
91-
# XXX: put some of the above behind GPUCompiler abstractions
92-
# (e.g., a compile-time version of `deferred_codegen`)
93-
end
94+
GPUCompiler.add_input_arguments!(job, mod, f, kernel_intrinsics)
9495
end
95-
return entry
96+
return
9697
end
9798

9899
## compiler implementation (cache, configure, compile, and link)

src/device/random.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,10 @@ end
2222
end
2323

2424
# initialization function, called automatically at the start of each kernel
25-
function initialize_rng_state(random_keys_ptr::LLVMPtr{UInt32, AS.Workgroup}, random_counters_ptr::LLVMPtr{UInt32, AS.Workgroup})
26-
n = get_num_sub_groups()
27-
random_keys = CLDeviceArray{UInt32, 1, AS.Workgroup}((n,), random_keys_ptr)
28-
random_counters = CLDeviceArray{UInt32, 1, AS.Workgroup}((n,), random_counters_ptr)
29-
25+
function initialize_rng_state()
3026
subgroup_id = get_sub_group_id()
31-
@inbounds random_keys[subgroup_id] = kernel_state().random_seed
32-
@inbounds random_counters[subgroup_id] = 0
27+
@inbounds global_random_keys()[subgroup_id] = kernel_state().random_seed
28+
@inbounds global_random_counters()[subgroup_id] = 0
3329
end
3430

3531
# generators
@@ -50,10 +46,8 @@ struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64} end
5046
elseif field === :ctr1
5147
@inbounds global_random_counters()[subgroup_id]
5248
elseif field === :ctr2
53-
global_id = get_global_id(1) + (get_global_id(2) - Int32(1)) * get_global_size(1) +
54-
(get_global_id(3) - Int32(1)) * get_global_size(1) * get_global_size(2)
55-
global_id % UInt32
56-
end::UInt32
49+
unsafe_trunc(UInt32, get_global_linear_id())
50+
end
5751
end
5852

5953
@inline function Base.setproperty!(rng::Philox2x32, field::Symbol, x)
@@ -64,6 +58,7 @@ end
6458
elseif field === :ctr1
6559
@inbounds global_random_counters()[subgroup_id] = x
6660
end
61+
return rng
6762
end
6863

6964
@device_override @inline Random.default_rng() = Philox2x32()

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2424
pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd"
2525

2626
[sources]
27-
GPUCompiler = {rev = "tb/kernel_state_reference", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}
27+
GPUCompiler = {rev = "tb/linked_module", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}
2828

2929
[compat]
3030
pocl_jll = "7.0"

test/device/random.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function apply_seed(seed)
1818
end
1919
end
2020

21-
eltypes = [filter(x -> !(x <: Complex), GPUArraysTestSuite.supported_eltypes(CLArray)); UInt16; UInt32; UInt64]
21+
eltypes = [filter(x -> !(x <: Complex), GPUArraysTestSuite.supported_eltypes(CLArray)); Bool; UInt16; UInt32; UInt64]
2222

2323
@testset "rand($T), seed $seed" for T in eltypes, seed in (nothing, #=missing,=# 1234)
2424
# different kernel invocations should get different numbers

0 commit comments

Comments
 (0)