Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Adapt"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "2.4.0"
version = "3.0.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
50 changes: 25 additions & 25 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,30 @@ permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm
export WrappedArray

adapt_structure(to, A::SubArray) =
SubArray(adapt(to, parent(A)), adapt(to, parentindices(A)))
SubArray(adapt(to, Base.parent(A)), adapt(to, parentindices(A)))
adapt_structure(to, A::Base.LogicalIndex) =
Base.LogicalIndex(adapt(to, A.mask))
adapt_structure(to, A::PermutedDimsArray) =
PermutedDimsArray(adapt(to, parent(A)), permutation(A))
PermutedDimsArray(adapt(to, Base.parent(A)), permutation(A))
adapt_structure(to, A::Base.ReshapedArray) =
Base.reshape(adapt(to, parent(A)), size(A))
Base.reshape(adapt(to, Base.parent(A)), size(A))
adapt_structure(to, A::Base.ReinterpretArray) =
Base.reinterpret(eltype(A), adapt(to, parent(A)))
Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A)))

adapt_structure(to, A::LinearAlgebra.Adjoint) =
LinearAlgebra.adjoint(adapt(to, parent(A)))
LinearAlgebra.adjoint(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Transpose) =
LinearAlgebra.transpose(adapt(to, parent(A)))
LinearAlgebra.transpose(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.LowerTriangular) =
LinearAlgebra.LowerTriangular(adapt(to, parent(A)))
LinearAlgebra.LowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitLowerTriangular) =
LinearAlgebra.UnitLowerTriangular(adapt(to, parent(A)))
LinearAlgebra.UnitLowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UpperTriangular) =
LinearAlgebra.UpperTriangular(adapt(to, parent(A)))
LinearAlgebra.UpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) =
LinearAlgebra.UnitUpperTriangular(adapt(to, parent(A)))
LinearAlgebra.UnitUpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Diagonal) =
LinearAlgebra.Diagonal(adapt(to, parent(A)))
LinearAlgebra.Diagonal(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Tridiagonal) =
LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du))

Expand Down Expand Up @@ -103,26 +103,26 @@ WrappedArray{T,N,Src,Dst} = Union{
# https://github.com/JuliaLang/julia/pull/31563

# accessors for extracting information about the wrapper type
Base.ndims(::Type{<:Base.LogicalIndex}) = 1
Base.ndims(::Type{<:LinearAlgebra.Adjoint}) = 2
Base.ndims(::Type{<:LinearAlgebra.Transpose}) = 2
Base.ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2
Base.ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2
Base.ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2
Base.ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2
Base.ndims(::Type{<:LinearAlgebra.Diagonal}) = 2
Base.ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2
Base.ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N

Base.eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar
ndims(::Type{<:Base.LogicalIndex}) = 1
ndims(::Type{<:LinearAlgebra.Adjoint}) = 2
ndims(::Type{<:LinearAlgebra.Transpose}) = 2
ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2
ndims(::Type{<:LinearAlgebra.Diagonal}) = 2
ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2
ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N

eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar

for T in [:(Base.LogicalIndex{<:Any,<:Src}),
:(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}),
:(WrappedReinterpretArray{<:Any,<:Any,<:Src}),
:(WrappedReshapedArray{<:Any,<:Any,<:Src}),
:(WrappedSubArray{<:Any,<:Any,<:Src})]
@eval begin
Base.parent(::Type{<:$T}) where {Src} = Src.name.wrapper
parent(::Type{<:$T}) where {Src} = Src.name.wrapper
end
end
Base.parent(::Type{<:WrappedArray{<:Any,<:Any,<:Any,Dst}}) where {Dst} = Dst.name.wrapper
parent(::Type{<:WrappedArray{<:Any,<:Any,<:Any,Dst}}) where {Dst} = Dst.name.wrapper
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ const d = CustomArray{Float64,1}(rand(3))
@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomArray

@testset "Extracting type information" begin
@test ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2
@test ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3
@test Adapt.ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2
@test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3

@test parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array
@test parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array
@test Adapt.parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array
@test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array
end