Skip to content
6 changes: 6 additions & 0 deletions src/axis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ ViewAxis{Inds,IdxMap,Ax}() where {Inds,IdxMap,Ax} = ViewAxis(Inds, Ax())
ViewAxis(Inds, IdxMap) = ViewAxis(Inds, Axis(IdxMap))
ViewAxis(Inds) = Inds

Base.length(ax::ViewAxis{Inds}) where Inds = length(Inds)
# Fix https://github.com/Deltares/Ribasim/issues/2028
Base.getindex(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}, idx::Integer) where {Inds,IdxMap} = Inds[idx]
Base.iterate(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}) where {Inds,IdxMap} = iterate(Inds)
Base.iterate(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}, idx) where {Inds,IdxMap} = iterate(Inds, idx)

const View = ViewAxis
const NullOrFlatView{Inds,IdxMap} = ViewAxis{Inds,IdxMap,<:NullorFlatAxis}

Expand Down
2 changes: 2 additions & 0 deletions src/componentindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ const NullComponentIndex{Idx} = ComponentIndex{Idx, NullAxis}

Base.:(==)(ci1::ComponentIndex, ci2::ComponentIndex) = ci1.idx == ci2.idx && ci1.ax == ci2.ax

Base.length(ci::ComponentIndex) = length(ci.idx)


"""
KeepIndex(idx)
Expand Down
51 changes: 51 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ end
x = ComponentArray(b=1, a=2)
@test merge(NamedTuple(), x) == NamedTuple(x)
@test kw_fun(; x...) == 2

@test length(ViewAxis(2:7, ShapedAxis((2,3)))) == 6
end

@testset "Get" begin
Expand Down Expand Up @@ -385,6 +387,12 @@ end
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3)))
ax2 = getaxes(ca2)[1]
@test ax2[(:a, :c)] == ax2[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3)))))

@test length(ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis())) == 1
@test length(ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4)))) == 2
@test length(ComponentArrays.ComponentIndex(5:8, Axis(a = r2v(1:3), b = 4))) == 4
@test length(ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3)))) == 3
@test length(ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3)))))) == 7
end

@testset "KeepIndex" begin
Expand Down Expand Up @@ -843,6 +851,49 @@ end
@test all(Xstack4_dcolon[:a, :, :] .== Xstack4_noca_dcolon[1, :, :])
@test all(Xstack4_dcolon[:b, :, :] .== Xstack4_noca_dcolon[2:3, :, :])
end

# Test fix https://github.com/Deltares/Ribasim/issues/2028
a = range(0.0, 1.0, length=0) |> collect
b = range(0.0, 1.0; length=2) |> collect
c = range(0.0, 1.0, length=3) |> collect
d = range(0.0, 1.0; length=0) |> collect
u = ComponentVector(a=a, b=b, c=c, d=d)

function get_state_index(
idx::Int,
::ComponentVector{A, B, <:Tuple{<:Axis{NT}}},
component_name::Symbol
) where {A, B, NT}
for (comp, range) in pairs(NT)
if comp == component_name
return range[idx]
end
end
return nothing
end

@test_throws BoundsError get_state_index(1, u, :a)
@test_throws BoundsError get_state_index(2, u, :a)
@test get_state_index(1, u, :b) == 1
@test get_state_index(2, u, :b) == 2
@test get_state_index(1, u, :c) == 3
@test get_state_index(2, u, :c) == 4
@test get_state_index(3, u, :c) == 5
@test_throws BoundsError get_state_index(1, u, :d)
@test_throws BoundsError get_state_index(2, u, :d)

# Must be a better way to make sure we can `Base.iterate` the `ViewAxis{UnitRange, Shaped1DAxis}`.
nt = ComponentArrays.indexmap(getaxes(u)[1])
for (i, idx) in enumerate(nt.a)
end
for (i, idx) in enumerate(nt.b)
@test idx == i
end
for (i, idx) in enumerate(nt.c)
@test idx == i + 2
end
for (i, idx) in enumerate(nt.d)
end
end

@testset "axpy! / axpby!" begin
Expand Down
Loading