diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 18f935e..cc5c380 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -22,6 +22,7 @@ jobs: version: - '1.6' # LTS - '1' + - '~1.10.0-0' - 'nightly' os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index 7892859..aa1a89e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Tracker" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.28" +version = "0.2.30" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/lib/array.jl b/src/lib/array.jl index 001f302..c31f5bc 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -164,26 +164,44 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) end for (T, S) in [(:TrackedArray, :TrackedArray), (:TrackedArray, :AbstractArray), (:AbstractArray, :TrackedArray)] - @eval Base.vcat(A::$T, B::$S, Cs::AbstractArray...) = track(vcat, A, B, Cs...) - @eval Base.hcat(A::$T, B::$S, Cs::AbstractArray...) = track(hcat, A, B, Cs...) + for op in (:vcat, :hcat) + @eval Base.$(op)(A::$T, B::$S, Cs::AbstractArray...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T, B::$S, Cs::TrackedArray...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::AbstractArray{<:Number}...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::TrackedArray{<:Number}...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T, B::$S) = track($(op), A, B) + end end for (T, S) in [(:TrackedVector, :TrackedVector), (:TrackedVector, :AbstractVector), (:AbstractVector, :TrackedVector)] - @eval Base.vcat(A::$T, B::$S, Cs::AbstractVector...) = track(vcat, A, B, Cs...) + for op in (:vcat, :hcat) + @eval Base.$(op)(A::$T, B::$S, Cs::AbstractVector...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T, B::$S, Cs::TrackedVector...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::AbstractVector{<:Number}...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::TrackedVector{<:Number}...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T, B::$S) = track($(op), A, B) + end end for (T, S) in [(:TrackedVecOrMat, :TrackedVecOrMat), (:TrackedVecOrMat, :AbstractVecOrMat), (:AbstractVecOrMat, :TrackedVecOrMat)] - @eval Base.vcat(A::$T, B::$S, Cs::AbstractVecOrMat...) = track(vcat, A, B, Cs...) - @eval Base.hcat(A::$T, B::$S, Cs::AbstractVecOrMat...) = track(hcat, A, B, Cs...) + for op in (:vcat, :hcat) + @eval Base.$(op)(A::$T, B::$S, Cs::AbstractVecOrMat...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T, B::$S, Cs::TrackedVecOrMat...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::AbstractVecOrMat{<:Number}...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::TrackedVecOrMat{<:Number}...) = track($(op), A, B, Cs...) + @eval Base.$(op)(A::$T, B::$S) = track($(op), A, B) + end end for (T, S) in [(:TrackedArray, :Real), (:Real, :TrackedArray), (:TrackedArray, :TrackedArray)] - @eval Base.vcat(A::$T, B::$S, Cs::Union{AbstractArray, Real}...) = track(vcat, A, B, Cs...) - @eval Base.hcat(A::$T, B::$S, Cs::Union{AbstractArray, Real}...) = track(hcat, A, B, Cs...) + @eval Base.vcat(A::$T, B::$S, Cs::Union{TrackedArray, AbstractArray, Real}...) = track(vcat, A, B, Cs...) + @eval Base.hcat(A::$T, B::$S, Cs::Union{TrackedArray, AbstractArray, Real}...) = track(hcat, A, B, Cs...) + if T == :Real || S == :Real + @eval Base.vcat(A::$T, B::$S) = track(vcat, A, B) + @eval Base.hcat(A::$T, B::$S) = track(hcat, A, B) + end end for (T, S) in [(:TrackedReal, :Real), (:Real, :TrackedReal), (:TrackedReal, :TrackedReal)] @eval Base.vcat(A::$T, B::$S, Cs::Real...) = track(vcat, A, B, Cs...) @eval Base.hcat(A::$T, B::$S, Cs::Real...) = track(hcat, A, B, Cs...) end -Base.vcat(A::TrackedVecOrMat{T1, <:AbstractArray}, B::TrackedVecOrMat{T2, <:AbstractArray}) where {T1, T2} = track(vcat, A, B) -Base.hcat(A::TrackedVecOrMat{T1, <:AbstractArray}, B::TrackedVecOrMat{T2, <:AbstractArray}) where {T1, T2} = track(hcat, A, B) Base.vcat(A::TrackedArray) = track(vcat, A) Base.hcat(A::TrackedArray) = track(hcat, A) diff --git a/test/tracker.jl b/test/tracker.jl index c6f03bd..46ca578 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -92,7 +92,7 @@ end @test gradtest(hcatf, rand(5), rand(5), rand(5,2)) @test gradtest(hcatf, rand(5)', rand(1,3)) @test gradtest(hcatf, rand(5), rand(5,2)) -end + end @testset "1-arg $catf" for catf in [vcat, cat1, rvcat, hcat, cat2, rhcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))] @test gradtest(catf, rand(5))