diff --git a/Project.toml b/Project.toml index 7b267372..9e6ad1dc 100644 --- a/Project.toml +++ b/Project.toml @@ -3,19 +3,22 @@ uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" version = "0.4.2" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] +Adapt = "1" DataAPI = "1" Tables = "1" julia = "1" [extras] +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "OffsetArrays", "PooledArrays", "WeakRefStrings"] +test = ["Test", "GPUArrays", "OffsetArrays", "PooledArrays", "WeakRefStrings"] diff --git a/src/StructArrays.jl b/src/StructArrays.jl index c9f52df1..01406726 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -16,4 +16,8 @@ include("groupjoin.jl") include("lazy.jl") include("tables.jl") +# Use Adapt allows for automatic conversion of CPU to GPU StructArrays +import Adapt +Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) + end # module diff --git a/src/structarray.jl b/src/structarray.jl index 8a1ea84c..4d1d517f 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -134,12 +134,11 @@ Base.axes(s::StructArray) = axes(fieldarrays(s)[1]) Base.axes(s::StructArray{<:Any, <:Any, <:EmptyTup}) = (1:0,) get_ith(cols::NamedTuple, I...) = get_ith(Tuple(cols), I...) -function get_ith(cols::NTuple{N, Any}, I...) where N - ntuple(N) do i - @inbounds res = getfield(cols, i)[I...] - return res - end +function get_ith(cols::Tuple, I...) + @inbounds r = first(cols)[I...] + return (r, get_ith(Base.tail(cols), I...)...) end +get_ith(::Tuple{}, I...) = () Base.@propagate_inbounds function Base.getindex(x::StructArray{T, <:Any, <:Any, CartesianIndex{N}}, I::Vararg{Int, N}) where {T, N} cols = fieldarrays(x) diff --git a/src/utils.jl b/src/utils.jl index 355d3168..9ac1f521 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -28,19 +28,25 @@ else const _getproperty = getproperty end -function _foreachfield(names, xs) +function _foreachfield(names, L) + vars = ntuple(i -> gensym(), L) exprs = Expr[] + for (i, v) in enumerate(vars) + push!(exprs, Expr(:(=), v, Expr(:call, :getfield, :xs, i))) + end for field in names sym = QuoteNode(field) - args = [Expr(:call, :_getproperty, :(getfield(xs, $j)), sym) for j in 1:length(xs)] + args = [Expr(:call, :_getproperty, var, sym) for var in vars] push!(exprs, Expr(:call, :f, args...)) end push!(exprs, :(return nothing)) return Expr(:block, exprs...) end -@generated foreachfield(::Type{<:NamedTuple{names}}, f, xs...) where {names} = _foreachfield(names, xs) -@generated foreachfield(::Type{<:NTuple{N, Any}}, f, xs...) where {N} = _foreachfield(Base.OneTo(N), xs) +@generated foreachfield(::Type{<:NamedTuple{names}}, f, xs::Vararg{Any, L}) where {names, L} = + _foreachfield(names, L) +@generated foreachfield(::Type{<:NTuple{N, Any}}, f, xs::Vararg{Any, L}) where {N, L} = + _foreachfield(Base.OneTo(N), L) foreachfield(f, x::T, xs...) where {T} = foreachfield(staticschema(T), f, x, xs...) diff --git a/test/runtests.jl b/test/runtests.jl index e9610602..66fb7605 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,8 @@ using StructArrays: staticschema, iscompatible, _promote_typejoin, append!! using OffsetArrays: OffsetArray import Tables, PooledArrays, WeakRefStrings using DataAPI: refarray, refvalue +using Adapt: adapt +import GPUArrays using Test @testset "index" begin @@ -700,3 +702,13 @@ end @test vcat(dest, StructVector(makeitr())) == append!!(copy(dest), makeitr()) end end + +@testset "adapt" begin + s = StructArray(a = 1:10, b = StructArray(c = 1:10, d = 1:10)) + t = adapt(Array, s) + @test propertynames(t) == (:a, :b) + @test s == t + @test t.a isa Array + @test t.b.c isa Array + @test t.b.d isa Array +end