Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

Expand All @@ -30,6 +31,7 @@ GPUArraysCore = "0.1"
IteratorInterfaceExtensions = "1"
RecipesBase = "0.7, 0.8, 1.0"
StaticArraysCore = "1.1"
SymbolicIndexingInterface = "0.1"
Tables = "1"
ZygoteRules = "0.2"
julia = "1.6"
Expand Down
1 change: 1 addition & 0 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module RecursiveArrayTools
using DocStringExtensions
using RecipesBase, StaticArraysCore, Statistics,
ArrayInterfaceCore, LinearAlgebra
using SymbolicIndexingInterface

import ChainRulesCore
import ChainRulesCore: NoTangent
Expand Down
4 changes: 2 additions & 2 deletions src/tabletraits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray)
N = length(A.u[1])
names = [
:timestamp,
(A.syms !== nothing ? (A.syms[i] for i in 1:N) :
(A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) :
(Symbol("value", i) for i in 1:N))...,
]
types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...]
else
names = [:timestamp, A.syms !== nothing ? A.syms[1] : :value]
names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value]
types = Type[eltype(A.t), VT]
end
return AbstractDiffEqArrayRows(names, types, A.t, A.u)
Expand Down
70 changes: 41 additions & 29 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ A[1,:] # all time periods for f(t)
A.t
```
"""
mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N, A}
mutable struct DiffEqArray{T, N, A, B, C, E, F} <: AbstractDiffEqArray{T, N, A}
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
t::B
syms::C
indepsym::D
sc::C
observed::E
p::F
end
Expand Down Expand Up @@ -94,11 +93,23 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec)

DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
function DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N}
sc = if isnothing(indepsym) || indepsym isa AbstractArray
SymbolCache{typeof(syms),typeof(indepsym),Nothing}(syms, indepsym, nothing)
else
SymbolCache{typeof(syms),Vector{typeof(indepsym)},Nothing}(syms, [indepsym], nothing)
end
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)}(vec, ts, sc, observed, p)
end
# Assume that the first element is representative of all other elements
DiffEqArray(vec::AbstractVector,ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)), syms, indepsym, observed, p)
function DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N, VT<:AbstractArray{T, N}}
DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
sc = if isnothing(indepsym) || indepsym isa AbstractArray
SymbolCache{typeof(syms),typeof(indepsym),Nothing}(syms, indepsym, nothing)
else
SymbolCache{typeof(syms),Vector{typeof(indepsym)},Nothing}(syms, [indepsym], nothing)
end
DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)}(vec, ts, sc, observed, p)
end

# Interface for the linear indexing. This is just a view of the underlying nested structure
Expand Down Expand Up @@ -138,37 +149,39 @@ Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,::Co
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, ::Colon,i::Int) where {T, N} = A.u[i]
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,II::AbstractArray{Int}) where {T, N} = [A.u[j][i] for j in II]
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym) where {T, N}
if issymbollike(sym) && A.syms !== nothing
i = findfirst(isequal(Symbol(sym)),A.syms)
else
i = sym
end

if i === nothing
if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym
A.t
if issymbollike(sym) && !isnothing(A.sc)
if is_indep_sym(A.sc, sym)
return A.t
elseif is_state_sym(A.sc, sym)
return getindex.(A.u, state_sym_to_index(A.sc, sym))
elseif is_param_sym(A.sc, sym)
return A.p[param_sym_to_index(A.sc, sym)]
else
return observed(A, sym, :)
end
elseif all(issymbollike, sym) && !isnothing(A.sc)
if all(Base.Fix1(is_param_sym, A.sc), sym)
return getindex.((A,), sym)
else
observed(A,sym,:)
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
end
else
Base.getindex.(A.u, i)
return getindex.(A.u, sym)
end
end
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym,args...) where {T, N}
if issymbollike(sym) && A.syms !== nothing
i = findfirst(isequal(Symbol(sym)),A.syms)
else
i = sym
end

if i === nothing
if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym
A.t[args...]
if issymbollike(sym) && !isnothing(A.sc)
if is_indep_sym(A.sc, sym)
return A.t[args...]
elseif is_state_sym(A.sc, sym)
return A[sym][args...]
else
observed(A,sym,args...)
return observed(A, sym, args...)
end
elseif all(issymbollike, sym) && !isnothing(A.sc)
return reduce(vcat, map(s -> A[s, args...]', sym))
else
Base.getindex.(A.u, i, args...)
return getindex.(A.u, sym)
end
end
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, I::Int...) where {T, N} = A.u[I[end]][Base.front(I)...]
Expand Down Expand Up @@ -230,8 +243,7 @@ tuples(VA::DiffEqArray) = tuple.(VA.t,VA.u)
Base.copy(VA::AbstractDiffEqArray) = typeof(VA)(
copy(VA.u),
copy(VA.t),
(VA.syms===nothing) ? nothing : copy(VA.syms),
(VA.indepsym===nothing) ? nothing : copy(VA.indepsym),
(VA.sc===nothing) ? nothing : copy(VA.sc),
(VA.observed===nothing) ? nothing : copy(VA.observed),
(VA.p===nothing) ? nothing : copy(VA.p)
)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ if GROUP == "Core" || GROUP == "All"
@time @testset "Utils Tests" begin include("utils_test.jl") end
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
@time @testset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end
@time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
@time @testset "Table traits" begin include("tabletraits.jl") end
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
Expand Down
15 changes: 15 additions & 0 deletions test/symbolic_indexing_interface_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using RecursiveArrayTools, Test

t = 0.0:0.1:1.0
f(x) = 2x
f2(x) = 3x

dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b], :t)
@test dx[:t] == t
@test dx[:a] == [f(x) for x in t]
@test dx[:b] == [f2(x) for x in t]

dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b], [:t])
@test dx[:t] == t
dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b])
@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym