diff --git a/ext/RecursiveArrayToolsStructArraysExt.jl b/ext/RecursiveArrayToolsStructArraysExt.jl index 80c8b71f..a77ca20a 100644 --- a/ext/RecursiveArrayToolsStructArraysExt.jl +++ b/ext/RecursiveArrayToolsStructArraysExt.jl @@ -3,7 +3,8 @@ module RecursiveArrayToolsStructArraysExt import RecursiveArrayTools, StructArrays RecursiveArrayTools.rewrap(::StructArrays.StructArray, u) = StructArrays.StructArray(u) -using RecursiveArrayTools: VectorOfArray +using RecursiveArrayTools: VectorOfArray, VectorOfArrayStyle, ArrayInterface, unpack_voa, + narrays, StaticArraysCore using StructArrays: StructArray const VectorOfStructArray{T, N} = VectorOfArray{T, N, <:StructArray} @@ -17,6 +18,7 @@ const VectorOfStructArray{T, N} = VectorOfArray{T, N, <:StructArray} # # To avoid this, we can materialize a struct entry, modify it, and then use `setindex!` # with the modified struct entry. +# function Base.setindex!(VA::VectorOfStructArray{T, N}, v, I::Int...) where {T, N} u_I = VA.u[I[end]] @@ -24,4 +26,37 @@ function Base.setindex!(VA::VectorOfStructArray{T, N}, v, return VA.u[I[end]] = u_I end +for (type, N_expr) in [ + (Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))), + (Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u))) +] + @eval @inline function Base.copyto!(dest::VectorOfStructArray, + bc::$type) + bc = Broadcast.flatten(bc) + N = $N_expr + @inbounds for i in 1:N + dest_i = dest[:, i] + if dest_i isa AbstractArray + if ArrayInterface.ismutable(dest_i) + copyto!(dest_i, unpack_voa(bc, i)) + else + unpacked = unpack_voa(bc, i) + arr_type = StaticArraysCore.similar_type(dest_i) + dest_i = if length(unpacked) == 1 && length(dest_i) == 1 + arr_type(unpacked[1]) + elseif length(unpacked) == 1 + fill(copy(unpacked), arr_type) + else + arr_type(unpacked[j] for j in eachindex(unpacked)) + end + end + else + dest_i = copy(unpack_voa(bc, i)) + end + dest[:, i] = dest_i + end + dest + end +end + end diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index 37c24857..442e761f 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -265,9 +265,18 @@ num_allocs = @allocations foo!(u_matrix) # check VectorOfArray indexing for a StructArray of mutable structs using StructArrays -using StaticArrays: MVector +using StaticArrays: MVector, SVector x = VectorOfArray(StructArray{MVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1))) +y = 2 * x -# check VectorOfArray assignment +# check mutable VectorOfArray assignment and broadcast x[1, 1] = 10 @test x[1, 1] == 10 +@. x = y +@test all(all.(y .== x)) + +# check immutable VectorOfArray broadcast +x = VectorOfArray(StructArray{SVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1))) +y = 2 * x +@. x = y +@test all(all.(y .== x))