Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6626,19 +6626,31 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
mstr = if job.config.params.ABI <: InlineABI
""
else
fixup_callconv!(mod, JIT.get_tm())
for f in functions(mod)
for i in 1:length(parameters(f))
for a in collect(parameter_attributes(f, i))
if kind(a) == "enzyme_sret"
API.EnzymeDumpValueRef(f)
end
@assert kind(a) != "enzyme_sret"
@assert kind(a) != "enzyme_sret_v"
end
end
end
string(mod)
end
if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI
if DumpPrePostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
end
post_optimize!(mod, JIT.get_tm())
post_optimize!(mod, JIT.get_tm(); callconv=false)
if DumpPostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
end
else
propagate_returned!(mod)
Compiler.JIT.prepare!(mod)
Compiler.JIT.prepare!(mod)
end
mstr
else
Expand Down
9 changes: 8 additions & 1 deletion src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ end
const DumpPreCallConv = Ref(false)
const DumpPostCallConv = Ref(false)

function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine)
addr13NoAlias(mod)
removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
if DumpPreCallConv[]
Expand Down Expand Up @@ -397,6 +397,13 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
),
)
end
return
end

function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true; callconv::Bool = true)
if callconv
fixup_callconv!(mod, tm)
end
@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
register!(pb, ReinsertGCMarkerPass())
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,11 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
end

if ekind == "enzyme_sret"
ety = parse(UInt, LLVM.value(attr))
ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
ety = LLVM.LLVMType(ety)
ety = parse(UInt, LLVM.value(attr))
ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
ety = LLVM.LLVMType(ety)
if !LLVM.is_opaque(vt)
@assert ety == eltype(vt)
@assert ety == eltype(vt) "Mismatched sret type $(string(fn))\nidx=$idx\nety ($(string(ety))) != eltype(vt) (vt = $(string(vt)))"
end

return ety
Expand Down
19 changes: 18 additions & 1 deletion test/ext/staticarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,21 @@ end
@test res ≈ [1.0, 0.0]
res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1]
@test res ≈ [1.0, 0.0]
end
end

function inner_forhess(x)
return tanh.(x)
end

function for_hess(x)
return sum(inner_forhess(x))
end

grad_forhess(x) = autodiff(Reverse, for_hess, Active, Active(x))[1][1]
hess(x) = jacobian(Forward, grad_forhess, x)[1]

@testset "StaticArrays hessian" begin
x = @SVector zeros(10)
res = [-2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0]
@test jacobian(Forward, grad_forhess, x)[1] ≈ res
end
Loading