Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

Expand All @@ -30,6 +31,7 @@ GPUArraysCore = "0.1"
IteratorInterfaceExtensions = "1"
RecipesBase = "0.7, 0.8, 1.0"
StaticArraysCore = "1.1"
SymbolicIndexingInterface = "0.1"
Tables = "1"
ZygoteRules = "0.2"
julia = "1.6"
Expand Down
1 change: 1 addition & 0 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module RecursiveArrayTools
using DocStringExtensions
using RecipesBase, StaticArraysCore, Statistics,
ArrayInterfaceCore, LinearAlgebra
using SymbolicIndexingInterface

import ChainRulesCore
import ChainRulesCore: NoTangent
Expand Down
4 changes: 2 additions & 2 deletions src/tabletraits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray)
N = length(A.u[1])
names = [
:timestamp,
(A.syms !== nothing ? (A.syms[i] for i in 1:N) :
(A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) :
(Symbol("value", i) for i in 1:N))...,
]
types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...]
else
names = [:timestamp, A.syms !== nothing ? A.syms[1] : :value]
names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value]
types = Type[eltype(A.t), VT]
end
return AbstractDiffEqArrayRows(names, types, A.t, A.u)
Expand Down
70 changes: 41 additions & 29 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ A[1,:] # all time periods for f(t)
A.t
```
"""
mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N, A}
mutable struct DiffEqArray{T, N, A, B, C, E, F} <: AbstractDiffEqArray{T, N, A}
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
t::B
syms::C
indepsym::D
sc::C
observed::E
p::F
end
Expand Down Expand Up @@ -94,11 +93,23 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec)

DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
function DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N}
sc = if isnothing(indepsym) || indepsym isa AbstractArray
SymbolCache{typeof(syms),typeof(indepsym),Nothing}(syms, indepsym, nothing)
else
SymbolCache{typeof(syms),Vector{typeof(indepsym)},Nothing}(syms, [indepsym], nothing)
end
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)}(vec, ts, sc, observed, p)
end
# Assume that the first element is representative of all other elements
DiffEqArray(vec::AbstractVector,ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)), syms, indepsym, observed, p)
function DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N, VT<:AbstractArray{T, N}}
DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
sc = if isnothing(indepsym) || indepsym isa AbstractArray
SymbolCache{typeof(syms),typeof(indepsym),Nothing}(syms, indepsym, nothing)
else
SymbolCache{typeof(syms),Vector{typeof(indepsym)},Nothing}(syms, [indepsym], nothing)
end
DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)}(vec, ts, sc, observed, p)
end

# Interface for the linear indexing. This is just a view of the underlying nested structure
Expand Down Expand Up @@ -138,37 +149,39 @@ Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,::Co
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, ::Colon,i::Int) where {T, N} = A.u[i]
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,II::AbstractArray{Int}) where {T, N} = [A.u[j][i] for j in II]
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym) where {T, N}
if issymbollike(sym) && A.syms !== nothing
i = findfirst(isequal(Symbol(sym)),A.syms)
else
i = sym
end

if i === nothing
if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym
A.t
if issymbollike(sym) && !isnothing(A.sc)
if is_indep_sym(A.sc, sym)
return A.t
elseif is_state_sym(A.sc, sym)
return getindex.(A.u, state_sym_to_index(A.sc, sym))
elseif is_param_sym(A.sc, sym)
return A.p[param_sym_to_index(A.sc, sym)]
else
return observed(A, sym, :)
end
elseif all(issymbollike, sym) && !isnothing(A.sc)
if all(Base.Fix1(is_param_sym, A.sc), sym)
return getindex.((A,), sym)
else
observed(A,sym,:)
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
end
else
Base.getindex.(A.u, i)
return getindex.(A.u, sym)
end
end
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym,args...) where {T, N}
if issymbollike(sym) && A.syms !== nothing
i = findfirst(isequal(Symbol(sym)),A.syms)
else
i = sym
end

if i === nothing
if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym
A.t[args...]
if issymbollike(sym) && !isnothing(A.sc)
if is_indep_sym(A.sc, sym)
return A.t[args...]
elseif is_state_sym(A.sc, sym)
return A[sym][args...]
else
observed(A,sym,args...)
return observed(A, sym, args...)
end
elseif all(issymbollike, sym) && !isnothing(A.sc)
return reduce(vcat, map(s -> A[s, args...]', sym))
else
Base.getindex.(A.u, i, args...)
return getindex.(A.u, sym)
end
end
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, I::Int...) where {T, N} = A.u[I[end]][Base.front(I)...]
Expand Down Expand Up @@ -230,8 +243,7 @@ tuples(VA::DiffEqArray) = tuple.(VA.t,VA.u)
Base.copy(VA::AbstractDiffEqArray) = typeof(VA)(
copy(VA.u),
copy(VA.t),
(VA.syms===nothing) ? nothing : copy(VA.syms),
(VA.indepsym===nothing) ? nothing : copy(VA.indepsym),
(VA.sc===nothing) ? nothing : copy(VA.sc),
(VA.observed===nothing) ? nothing : copy(VA.observed),
(VA.p===nothing) ? nothing : copy(VA.p)
)
Expand Down
15 changes: 15 additions & 0 deletions test/symbolic_indexing_interface_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using RecursiveArrayTools, Test

t = 0.0:0.1:1.0
f(x) = 2x
f2(x) = 3x

dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b], :t)
@test dx[:t] == t
@test dx[:a] == [f(x) for x in t]
@test dx[:b] == [f2(x) for x in t]

dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b], [:t])
@test dx[:t] == t
dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b])
@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym