@@ -18,7 +18,7 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1818 in (fn, known_intrinsics) ||
1919 contains (fn, " __spirv_" )
2020
21- GPUCompiler. kernel_state_type (job :: OpenCLCompilerJob ) = KernelState
21+ GPUCompiler. kernel_state_type (:: OpenCLCompilerJob ) = KernelState
2222
2323function GPUCompiler. finish_module! (@nospecialize (job:: OpenCLCompilerJob ),
2424 mod:: LLVM.Module , entry:: LLVM.Function )
@@ -28,71 +28,72 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
2828
2929 # if this kernel uses our RNG, we should prime the shared state.
3030 # XXX : these transformations should really happen at the Julia IR level...
31- if haskey (functions (mod), " julia.spirv.random_keys" )
31+ if haskey (functions (mod), " julia.spirv.random_keys" ) && job. config. kernel
32+ # insert call to `initialize_rng_state`
33+ f = initialize_rng_state
34+ ft = typeof (f)
35+ tt = Tuple{}
36+
37+ # create a deferred compilation job for `initialize_rng_state`
38+ src = methodinstance (ft, tt, GPUCompiler. tls_world_age ())
39+ cfg = CompilerConfig (job. config; kernel= false , name= nothing )
40+ job = CompilerJob (src, cfg, job. world)
41+ id = length (GPUCompiler. deferred_codegen_jobs) + 1
42+ GPUCompiler. deferred_codegen_jobs[id] = job
43+
44+ # generate IR for calls to `deferred_codegen` and the resulting function pointer
45+ top_bb = first (blocks (entry))
46+ bb = BasicBlock (top_bb, " initialize_rng" )
47+ @dispose builder= IRBuilder () begin
48+ position! (builder, bb)
49+ subprogram = LLVM. subprogram (entry)
50+ if subprogram != = nothing
51+ loc = DILocation (0 , 0 , subprogram)
52+ debuglocation! (builder, loc)
53+ end
54+ debuglocation! (builder, first (instructions (top_bb)))
55+
56+ # call the `deferred_codegen` marker function
57+ T_ptr = if LLVM. version () >= v " 17"
58+ LLVM. PointerType ()
59+ elseif VERSION >= v " 1.12.0-DEV.225"
60+ LLVM. PointerType (LLVM. Int8Type ())
61+ else
62+ LLVM. Int64Type ()
63+ end
64+ T_id = convert (LLVMType, Int)
65+ deferred_codegen_ft = LLVM. FunctionType (T_ptr, [T_id])
66+ deferred_codegen = if haskey (functions (mod), " deferred_codegen" )
67+ functions (mod)[" deferred_codegen" ]
68+ else
69+ LLVM. Function (mod, " deferred_codegen" , deferred_codegen_ft)
70+ end
71+ fptr = call! (builder, deferred_codegen_ft, deferred_codegen, [ConstantInt (id)])
72+
73+ # call the `initialize_rng_state` function
74+ rt = Core. Compiler. return_type (f, tt)
75+ llvm_rt = convert (LLVMType, rt)
76+ llvm_ft = LLVM. FunctionType (llvm_rt)
77+ fptr = inttoptr! (builder, fptr, LLVM. PointerType (llvm_ft))
78+ call! (builder, llvm_ft, fptr)
79+ br! (builder, top_bb)
80+ end
81+
82+ # XXX : put some of the above behind GPUCompiler abstractions
83+ # (e.g., a compile-time version of `deferred_codegen`)
84+ end
85+ return entry
86+ end
87+
88+ function GPUCompiler. finish_linked_module! (@nospecialize (job:: OpenCLCompilerJob ), mod:: LLVM.Module )
89+ for f in GPUCompiler. kernels (mod)
3290 kernel_intrinsics = Dict (
3391 " julia.spirv.random_keys" => (; name = " random_keys" , typ = LLVMPtr{UInt32, AS. Workgroup}),
3492 " julia.spirv.random_counters" => (; name = " random_counters" , typ = LLVMPtr{UInt32, AS. Workgroup}),
3593 )
36- entry = GPUCompiler. add_input_arguments! (job, mod, entry, kernel_intrinsics)
37-
38- if job. config. kernel
39- # insert call to `initialize_rng_state`
40- f = initialize_rng_state
41- ft = typeof (f)
42- tt = NTuple{2 , LLVMPtr{UInt32, AS. Workgroup}}
43-
44- # create a deferred compilation job for `initialize_rng_state`
45- src = methodinstance (ft, tt, GPUCompiler. tls_world_age ())
46- cfg = CompilerConfig (job. config; kernel= false , name= nothing )
47- job = CompilerJob (src, cfg, job. world)
48- id = length (GPUCompiler. deferred_codegen_jobs) + 1
49- GPUCompiler. deferred_codegen_jobs[id] = job
50-
51- # generate IR for calls to `deferred_codegen` and the resulting function pointer
52- top_bb = first (blocks (entry))
53- bb = BasicBlock (top_bb, " initialize_rng" )
54- @dispose builder= IRBuilder () begin
55- position! (builder, bb)
56- subprogram = LLVM. subprogram (entry)
57- if subprogram != = nothing
58- loc = DILocation (0 , 0 , subprogram)
59- debuglocation! (builder, loc)
60- end
61- debuglocation! (builder, first (instructions (top_bb)))
62-
63- # call the `deferred_codegen` marker function
64- T_ptr = if LLVM. version () >= v " 17"
65- LLVM. PointerType ()
66- elseif VERSION >= v " 1.12.0-DEV.225"
67- LLVM. PointerType (LLVM. Int8Type ())
68- else
69- LLVM. Int64Type ()
70- end
71- T_id = convert (LLVMType, Int)
72- deferred_codegen_ft = LLVM. FunctionType (T_ptr, [T_id])
73- deferred_codegen = if haskey (functions (mod), " deferred_codegen" )
74- functions (mod)[" deferred_codegen" ]
75- else
76- LLVM. Function (mod, " deferred_codegen" , deferred_codegen_ft)
77- end
78- fptr = call! (builder, deferred_codegen_ft, deferred_codegen, [ConstantInt (id)])
79-
80- # call the `initialize_rng_state` function
81- rt = Core. Compiler. return_type (f, tt)
82- llvm_rt = convert (LLVMType, rt)
83- llvm_ft = LLVM. FunctionType (llvm_rt, [convert (LLVMType, LLVMPtr{UInt32, AS. Workgroup}) for _ in 1 : 2 ])
84- fptr = inttoptr! (builder, fptr, LLVM. PointerType (llvm_ft))
85- random_keys = findfirst (arg -> name (arg) == " random_keys" , parameters (entry))
86- random_counters = findfirst (arg -> name (arg) == " random_counters" , parameters (entry))
87- call! (builder, llvm_ft, fptr, parameters (entry)[[random_keys, random_counters]])
88- br! (builder, top_bb)
89- end
90-
91- # XXX : put some of the above behind GPUCompiler abstractions
92- # (e.g., a compile-time version of `deferred_codegen`)
93- end
94+ GPUCompiler. add_input_arguments! (job, mod, f, kernel_intrinsics)
9495 end
95- return entry
96+ return
9697end
9798
9899# # compiler implementation (cache, configure, compile, and link)
0 commit comments