diff --git a/Project.toml b/Project.toml index 89750f77..b1047735 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" @@ -30,6 +31,7 @@ GPUArraysCore = "0.1" IteratorInterfaceExtensions = "1" RecipesBase = "0.7, 0.8, 1.0" StaticArraysCore = "1.1" +SymbolicIndexingInterface = "0.1" Tables = "1" ZygoteRules = "0.2" julia = "1.6" diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index 407346ad..caf6907b 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -7,6 +7,7 @@ module RecursiveArrayTools using DocStringExtensions using RecipesBase, StaticArraysCore, Statistics, ArrayInterfaceCore, LinearAlgebra +using SymbolicIndexingInterface import ChainRulesCore import ChainRulesCore: NoTangent diff --git a/src/tabletraits.jl b/src/tabletraits.jl index e7a70492..df4ccef7 100644 --- a/src/tabletraits.jl +++ b/src/tabletraits.jl @@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray) N = length(A.u[1]) names = [ :timestamp, - (A.syms !== nothing ? (A.syms[i] for i in 1:N) : + (A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) : (Symbol("value", i) for i in 1:N))..., ] types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...] else - names = [:timestamp, A.syms !== nothing ? A.syms[1] : :value] + names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value] types = Type[eltype(A.t), VT] end return AbstractDiffEqArrayRows(names, types, A.t, A.u) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index ca3d36f1..a06ee2d9 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -53,11 +53,10 @@ A[1,:] # all time periods for f(t) A.t ``` """ -mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N, A} +mutable struct DiffEqArray{T, N, A, B, C, E, F} <: AbstractDiffEqArray{T, N, A} u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}} t::B - syms::C - indepsym::D + sc::C observed::E p::F end @@ -94,11 +93,23 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{ VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec))) VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec) -DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p) +function DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N} + sc = if isnothing(indepsym) || indepsym isa AbstractArray + SymbolCache{typeof(syms),typeof(indepsym),Nothing}(syms, indepsym, nothing) + else + SymbolCache{typeof(syms),Vector{typeof(indepsym)},Nothing}(syms, [indepsym], nothing) + end + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)}(vec, ts, sc, observed, p) +end # Assume that the first element is representative of all other elements DiffEqArray(vec::AbstractVector,ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)), syms, indepsym, observed, p) function DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N, VT<:AbstractArray{T, N}} - DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p) + sc = if isnothing(indepsym) || indepsym isa AbstractArray + SymbolCache{typeof(syms),typeof(indepsym),Nothing}(syms, indepsym, nothing) + else + SymbolCache{typeof(syms),Vector{typeof(indepsym)},Nothing}(syms, [indepsym], nothing) + end + DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)}(vec, ts, sc, observed, p) end # Interface for the linear indexing. This is just a view of the underlying nested structure @@ -138,37 +149,39 @@ Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,::Co Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, ::Colon,i::Int) where {T, N} = A.u[i] Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,II::AbstractArray{Int}) where {T, N} = [A.u[j][i] for j in II] Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym) where {T, N} - if issymbollike(sym) && A.syms !== nothing - i = findfirst(isequal(Symbol(sym)),A.syms) - else - i = sym - end - - if i === nothing - if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym - A.t + if issymbollike(sym) && !isnothing(A.sc) + if is_indep_sym(A.sc, sym) + return A.t + elseif is_state_sym(A.sc, sym) + return getindex.(A.u, state_sym_to_index(A.sc, sym)) + elseif is_param_sym(A.sc, sym) + return A.p[param_sym_to_index(A.sc, sym)] + else + return observed(A, sym, :) + end + elseif all(issymbollike, sym) && !isnothing(A.sc) + if all(Base.Fix1(is_param_sym, A.sc), sym) + return getindex.((A,), sym) else - observed(A,sym,:) + return [getindex.((A,), sym, i) for i in eachindex(A.t)] end else - Base.getindex.(A.u, i) + return getindex.(A.u, sym) end end Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym,args...) where {T, N} - if issymbollike(sym) && A.syms !== nothing - i = findfirst(isequal(Symbol(sym)),A.syms) - else - i = sym - end - - if i === nothing - if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym - A.t[args...] + if issymbollike(sym) && !isnothing(A.sc) + if is_indep_sym(A.sc, sym) + return A.t[args...] + elseif is_state_sym(A.sc, sym) + return A[sym][args...] else - observed(A,sym,args...) + return observed(A, sym, args...) end + elseif all(issymbollike, sym) && !isnothing(A.sc) + return reduce(vcat, map(s -> A[s, args...]', sym)) else - Base.getindex.(A.u, i, args...) + return getindex.(A.u, sym) end end Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, I::Int...) where {T, N} = A.u[I[end]][Base.front(I)...] @@ -230,8 +243,7 @@ tuples(VA::DiffEqArray) = tuple.(VA.t,VA.u) Base.copy(VA::AbstractDiffEqArray) = typeof(VA)( copy(VA.u), copy(VA.t), - (VA.syms===nothing) ? nothing : copy(VA.syms), - (VA.indepsym===nothing) ? nothing : copy(VA.indepsym), + (VA.sc===nothing) ? nothing : copy(VA.sc), (VA.observed===nothing) ? nothing : copy(VA.observed), (VA.p===nothing) ? nothing : copy(VA.p) ) diff --git a/test/runtests.jl b/test/runtests.jl index 69eec144..890866c8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,7 @@ if GROUP == "Core" || GROUP == "All" @time @testset "Utils Tests" begin include("utils_test.jl") end @time @testset "Partitions Tests" begin include("partitions_test.jl") end @time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end + @time @testset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end @time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end @time @testset "Table traits" begin include("tabletraits.jl") end @time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end diff --git a/test/symbolic_indexing_interface_test.jl b/test/symbolic_indexing_interface_test.jl new file mode 100644 index 00000000..6afee2f1 --- /dev/null +++ b/test/symbolic_indexing_interface_test.jl @@ -0,0 +1,15 @@ +using RecursiveArrayTools, Test + +t = 0.0:0.1:1.0 +f(x) = 2x +f2(x) = 3x + +dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b], :t) +@test dx[:t] == t +@test dx[:a] == [f(x) for x in t] +@test dx[:b] == [f2(x) for x in t] + +dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b], [:t]) +@test dx[:t] == t +dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b]) +@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym