diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index d36e9af2a..6e5d38df7 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -89,15 +89,16 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} check_eltype(T) maxsize = prod(dims) * sizeof(T) - return GPUArrays.cached_alloc((JLArray, T, dims)) do + ref = GPUArrays.cached_alloc((JLArray, maxsize)) do data = Vector{UInt8}(undef, maxsize) - ref = DataRef(data) do data + DataRef(data) do data resize!(data, 0) end - obj = new{T, N}(ref, 0, dims) - finalizer(unsafe_free!, obj) - return obj - end::JLArray{T, N} + end + + obj = new{T, N}(ref, 0, dims) + finalizer(unsafe_free!, obj) + return obj end # low-level constructor for wrapping existing data diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl index 150cd88b6..881e6e88c 100644 --- a/src/host/abstractarray.jl +++ b/src/host/abstractarray.jl @@ -53,17 +53,20 @@ end # per-object state, with a flag to indicate whether the object has been freed. # this is to support multiple calls to `unsafe_free!` on the same object, -# while only lowering the referene count of the underlying data once. +# while only lowering the reference count of the underlying data once. mutable struct DataRef{D} rc::RefCounted{D} freed::Bool + cached::Bool end -function DataRef(finalizer, data::D) where {D} - rc = RefCounted{D}(data, finalizer, Threads.Atomic{Int}(1)) - DataRef{D}(rc, false) +function DataRef(finalizer, ref::D) where {D} + rc = RefCounted{D}(ref, finalizer, Threads.Atomic{Int}(1)) + DataRef{D}(rc, false, false) end -DataRef(data; kwargs...) = DataRef(nothing, data; kwargs...) +DataRef(ref; kwargs...) = DataRef(nothing, ref; kwargs...) 
+ +Base.sizeof(ref::DataRef) = sizeof(ref.rc[]) function Base.getindex(ref::DataRef) if ref.freed @@ -77,10 +80,16 @@ function Base.copy(ref::DataRef{D}) where {D} throw(ArgumentError("Attempt to copy a freed reference.")) end retain(ref.rc) - return DataRef{D}(ref.rc, false) + # copies of cached references are not managed by the cache, so + # we need to mark them as such to make sure their refcount can drop. + return DataRef{D}(ref.rc, false, false) end -function unsafe_free!(ref::DataRef, args...) +function unsafe_free!(ref::DataRef) + if ref.cached + # lifetimes of cached references are tied to the cache. + return + end if ref.freed # multiple frees *of the same object* are allowed. # we should only ever call `release` once per object, though, @@ -88,7 +97,7 @@ function unsafe_free!(ref::DataRef, args...) return end ref.freed = true - release(ref.rc, args...) + release(ref.rc) return end diff --git a/src/host/alloc_cache.jl b/src/host/alloc_cache.jl index 22775b2dd..442f375bb 100644 --- a/src/host/alloc_cache.jl +++ b/src/host/alloc_cache.jl @@ -8,8 +8,8 @@ end mutable struct AllocCache lock::ReentrantLock - busy::Dict{UInt64, Vector{Any}} # hash(key) => GPUArray[] - free::Dict{UInt64, Vector{Any}} + busy::Dict{UInt64, Vector{DataRef}} + free::Dict{UInt64, Vector{DataRef}} function AllocCache() cache = new( @@ -24,8 +24,8 @@ end function get_pool!(cache::AllocCache, pool::Symbol, uid::UInt64) pool = getproperty(cache, pool) uid_pool = get(pool, uid, nothing) - if uid_pool ≡ nothing - uid_pool = Base.@lock cache.lock pool[uid] = Any[] + if uid_pool === nothing + uid_pool = pool[uid] = DataRef[] end return uid_pool end @@ -33,34 +33,39 @@ end function cached_alloc(f, key) cache = ALLOC_CACHE[] if cache === nothing - return f()::AbstractGPUArray + return f()::DataRef end - x = nothing + ref = nothing uid = hash(key) - busy_pool = get_pool!(cache, :busy, uid) - free_pool = get_pool!(cache, :free, uid) - isempty(free_pool) && (x = f()::AbstractGPUArray) + Base.@lock 
cache.lock begin + free_pool = get_pool!(cache, :free, uid) + + if !isempty(free_pool) + ref = pop!(free_pool) + end + end - while !isempty(free_pool) && x ≡ nothing - tmp = Base.@lock cache.lock pop!(free_pool) - # Array was manually freed via `unsafe_free!`. - GPUArrays.storage(tmp).freed && continue - x = tmp + if ref === nothing + ref = f()::DataRef + ref.cached = true end - x ≡ nothing && (x = f()::AbstractGPUArray) - Base.@lock cache.lock push!(busy_pool, x) - return x + Base.@lock cache.lock begin + busy_pool = get_pool!(cache, :busy, uid) + push!(busy_pool, ref) + end + + return ref end function free_busy!(cache::AllocCache) - for uid in cache.busy.keys - busy_pool = get_pool!(cache, :busy, uid) - isempty(busy_pool) && continue + Base.@lock cache.lock begin + for uid in keys(cache.busy) + busy_pool = get_pool!(cache, :busy, uid) + isempty(busy_pool) && continue - Base.@lock cache.lock begin free_pool = get_pool!(cache, :free, uid) append!(free_pool, busy_pool) empty!(busy_pool) @@ -71,14 +76,13 @@ end function unsafe_free!(cache::AllocCache) Base.@lock cache.lock begin - for (_, pool) in cache.busy - isempty(pool) || error( - "Invalidating allocations cache that's currently in use. " * - "Invalidating inside `@cached` is not allowed." - ) + for pool in values(cache.busy) + isempty(pool) || error("Cannot invalidate a cache that's in active use") end - for (_, pool) in cache.free - map(unsafe_free!, pool) + for pool in values(cache.free), ref in pool + # release the reference + ref.cached = false + unsafe_free!(ref) end empty!(cache.free) end @@ -143,13 +147,11 @@ GPUArrays.unsafe_free!(cache) See [`@uncached`](@ref). 
""" macro cached(cache, expr) + try_expr = :(@with $(esc(ALLOC_CACHE)) => cache $(esc(expr))) + fin_expr = :(free_busy!(cache)) return quote - cache = $(esc(cache)) - GC.@preserve cache begin - res = @with $(esc(ALLOC_CACHE)) => cache $(esc(expr)) - free_busy!(cache) - res - end + local cache = $(esc(cache)) + GC.@preserve cache $(Expr(:tryfinally, try_expr, fin_expr)) end end diff --git a/test/testsuite/alloc_cache.jl b/test/testsuite/alloc_cache.jl index b032c8bda..e63ca6c2c 100644 --- a/test/testsuite/alloc_cache.jl +++ b/test/testsuite/alloc_cache.jl @@ -2,42 +2,98 @@ if AT <: AbstractGPUArray cache = GPUArrays.AllocCache() + # first allocation populates the cache T, dims = Float32, (1, 2, 3) GPUArrays.@cached cache begin - x1 = AT(zeros(T, dims)) + cached1 = AT(zeros(T, dims)) end - @test sizeof(cache) == sizeof(T) * prod(dims) + @test sizeof(cache) == sizeof(cached1) key = first(keys(cache.free)) @test length(cache.free[key]) == 1 @test length(cache.busy[key]) == 0 - @test x1 === cache.free[key][1] + @test cache.free[key][1] === GPUArrays.storage(cached1) - # Second allocation hits cache. + # second allocation hits the cache GPUArrays.@cached cache begin - x2 = AT(zeros(T, dims)) - # Does not hit the cache. 
- GPUArrays.@uncached x_free = AT(zeros(T, dims)) + cached2 = AT(zeros(T, dims)) + + # explicitly uncached ones don't + GPUArrays.@uncached uncached = AT(zeros(T, dims)) + end + @test sizeof(cache) == sizeof(cached2) + key = first(keys(cache.free)) + @test length(cache.free[key]) == 1 + @test length(cache.busy[key]) == 0 + @test cache.free[key][1] === GPUArrays.storage(cached2) + @test uncached !== cached2 + + # compatible shapes should also hit the cache + dims = (3, 2, 1) + GPUArrays.@cached cache begin + cached3 = AT(zeros(T, dims)) end - @test sizeof(cache) == sizeof(T) * prod(dims) + @test sizeof(cache) == sizeof(cached3) key = first(keys(cache.free)) @test length(cache.free[key]) == 1 @test length(cache.busy[key]) == 0 - @test x2 === cache.free[key][1] - @test x_free !== x2 + @test cache.free[key][1] === GPUArrays.storage(cached3) - # Third allocation is of different shape - allocates. + # as should compatible eltypes + T = Int32 + GPUArrays.@cached cache begin + cached4 = AT(zeros(T, dims)) + end + @test sizeof(cache) == sizeof(cached4) + key = first(keys(cache.free)) + @test length(cache.free[key]) == 1 + @test length(cache.busy[key]) == 0 + @test cache.free[key][1] === GPUArrays.storage(cached4) + + # different shapes should trigger a new allocation dims = (2, 2) GPUArrays.@cached cache begin - x3 = AT(zeros(T, dims)) + cached5 = AT(zeros(T, dims)) + + # we're allowed to early free arrays, which should be a no-op for cached data + GPUArrays.unsafe_free!(cached5) end + @test sizeof(cache) == sizeof(cached4) + sizeof(cached5) _keys = collect(keys(cache.free)) key2 = _keys[findfirst(i -> i != key, _keys)] @test length(cache.free[key]) == 1 @test length(cache.free[key2]) == 1 - @test x3 === cache.free[key2][1] + @test cache.free[key2][1] === GPUArrays.storage(cached5) + + # we should be able to re-use the early-freed + GPUArrays.@cached cache begin + cached5 = AT(zeros(T, dims)) + end + + # exceptions shouldn't cause issues + @test_throws "Allowed exception" 
GPUArrays.@cached cache begin + AT(zeros(T, dims)) + error("Allowed exception") + end + # NOTE: this should remain the last test before calling `unsafe_free!` below, + # as it caught an erroneous assertion in the original code. - # Freeing all memory held by cache. + # freeing all memory held by cache should free all allocations + @test !GPUArrays.storage(cached1).freed + @test GPUArrays.storage(cached1).cached + @test !GPUArrays.storage(cached5).freed + @test GPUArrays.storage(cached5).cached + @test !GPUArrays.storage(uncached).freed + @test !GPUArrays.storage(uncached).cached GPUArrays.unsafe_free!(cache) @test sizeof(cache) == 0 + @test GPUArrays.storage(cached1).freed + @test !GPUArrays.storage(cached1).cached + @test GPUArrays.storage(cached5).freed + @test !GPUArrays.storage(cached5).cached + @test !GPUArrays.storage(uncached).freed + ## test that the underlying data was freed as well + @test GPUArrays.storage(cached1).rc.count[] == 0 + @test GPUArrays.storage(cached5).rc.count[] == 0 + @test GPUArrays.storage(uncached).rc.count[] == 1 end end