diff --git a/src/jlgen.jl b/src/jlgen.jl index f6002733..1a71ae2f 100644 --- a/src/jlgen.jl +++ b/src/jlgen.jl @@ -180,22 +180,51 @@ const GLOBAL_CI_CACHES_LOCK = ReentrantLock() function CC.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance) # make sure the invalidation callback is attached to the method instance - callback(mi, max_world) = invalidate_code_cache(cache, mi, max_world) + add_codecache_callback!(cache, mi) + cis = get!(cache.dict, mi, CodeInstance[]) + push!(cis, ci) +end + +# invalidation (like invalidate_method_instance, but for our cache) +struct CodeCacheCallback + cache::CodeCache +end + +@static if VERSION ≥ v"1.11.0-DEV.798" + +function add_codecache_callback!(cache::CodeCache, mi::MethodInstance) + callback = CodeCacheCallback(cache) + CC.add_invalidation_callback!(callback, mi) +end +function (callback::CodeCacheCallback)(replaced::MethodInstance, max_world::UInt32) + cis = get(callback.cache.dict, replaced, nothing) + if cis === nothing + return + end + for ci in cis + if ci.max_world == ~0 % Csize_t + @assert ci.min_world - 1 <= max_world "attempting to set illogical constraints" + ci.max_world = max_world + end + @assert ci.max_world <= max_world + end +end + +else + +function add_codecache_callback!(cache::CodeCache, mi::MethodInstance) + callback = CodeCacheCallback(cache) if !isdefined(mi, :callbacks) mi.callbacks = Any[callback] elseif !in(callback, mi.callbacks) push!(mi.callbacks, callback) end - - cis = get!(cache.dict, mi, CodeInstance[]) - push!(cis, ci) end - -# invalidation (like invalidate_method_instance, but for our cache) -function invalidate_code_cache(cache::CodeCache, replaced::MethodInstance, max_world, seen=Set{MethodInstance}()) +function (callback::CodeCacheCallback)(replaced::MethodInstance, max_world::UInt32, + seen::Set{MethodInstance}=Set{MethodInstance}()) push!(seen, replaced) - cis = get(cache.dict, replaced, nothing) + cis = get(callback.cache.dict, replaced, nothing) if cis === nothing return end @@ -225,11 +254,12 @@ function invalidate_code_cache(cache::CodeCache, replaced::MethodInstance, max_w # replaced.backedges = Any[] for mi in backedges - invalidate_code_cache(cache, mi::MethodInstance, max_world, seen) + callback(mi::MethodInstance, max_world, seen) end end end +end ## method overrides