Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,22 +180,51 @@ const GLOBAL_CI_CACHES_LOCK = ReentrantLock()

function CC.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
# make sure the invalidation callback is attached to the method instance
callback(mi, max_world) = invalidate_code_cache(cache, mi, max_world)
add_codecache_callback!(cache, mi)
cis = get!(cache.dict, mi, CodeInstance[])
push!(cis, ci)
end

# invalidation (like invalidate_method_instance, but for our cache)
struct CodeCacheCallback
cache::CodeCache
end

@static if VERSION ≥ v"1.11.0-DEV.798"

function add_codecache_callback!(cache::CodeCache, mi::MethodInstance)
callback = CodeCacheCallback(cache)
CC.add_invalidation_callback!(callback, mi)
end
function (callback::CodeCacheCallback)(replaced::MethodInstance, max_world::UInt32)
cis = get(callback.cache.dict, replaced, nothing)
if cis === nothing
return
end
for ci in cis
if ci.max_world == ~0 % Csize_t
@assert ci.min_world - 1 <= max_world "attempting to set illogical constraints"
ci.max_world = max_world
end
@assert ci.max_world <= max_world
end
end

else

function add_codecache_callback!(cache::CodeCache, mi::MethodInstance)
callback = CodeCacheCallback(cache)
if !isdefined(mi, :callbacks)
mi.callbacks = Any[callback]
elseif !in(callback, mi.callbacks)
push!(mi.callbacks, callback)
end

cis = get!(cache.dict, mi, CodeInstance[])
push!(cis, ci)
end

# invalidation (like invalidate_method_instance, but for our cache)
function invalidate_code_cache(cache::CodeCache, replaced::MethodInstance, max_world, seen=Set{MethodInstance}())
function (callback::CodeCacheCallback)(replaced::MethodInstance, max_world::UInt32,
seen::Set{MethodInstance}=Set{MethodInstance}())
push!(seen, replaced)

cis = get(cache.dict, replaced, nothing)
cis = get(callback.cache.dict, replaced, nothing)
if cis === nothing
return
end
Expand Down Expand Up @@ -225,11 +254,12 @@ function invalidate_code_cache(cache::CodeCache, replaced::MethodInstance, max_w
# replaced.backedges = Any[]

for mi in backedges
invalidate_code_cache(cache, mi::MethodInstance, max_world, seen)
callback(mi::MethodInstance, max_world, seen)
end
end
end

end

## method overrides

Expand Down