Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/src/internals/varnamedtuple.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ For instance, if `setindex!!(vnt, @varname(a[1:5]), val)` has been set, then the
Not `@varname(a[1:10])`, nor `@varname(a[3])`, nor for anything else that overlaps with `@varname(a[1:5])`.
`haskey` likewise only returns true for `@varname(a[1:5])`, and `keys(vnt)` only has that as an element.

The size of a value, for the purposes of inserting it into a `PartialArray`, is determined by a call to `vnt_size`.
`vnt_size` falls back to calling `Base.size`.
The reason we define a distinct function is to be able to control its behaviour, if necessary, without type piracy.

## Limitations

This design has a several of benefits, for performance and generality, but it also has limitations:
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
# Necessary forward declarations
include("utils.jl")
include("varnamedtuple.jl")
using .VarNamedTuples: VarNamedTuple
using .VarNamedTuples: VarNamedTuples, VarNamedTuple
include("contexts.jl")
include("contexts/default.jl")
include("contexts/init.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ struct RangeAndLinked{T<:Tuple}
original_size::T
end

Base.size(ral::RangeAndLinked) = ral.original_size
VarNamedTuples.vnt_size(ral::RangeAndLinked) = ral.original_size

"""
VectorWithRanges{Tlink}(
Expand Down
32 changes: 22 additions & 10 deletions src/varnamedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using BangBang
using Accessors
using ..DynamicPPL: _compose_no_identity

export VarNamedTuple
export VarNamedTuple, vnt_size

# We define our own getindex, setindex!!, and haskey functions, which we use to
# get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be
Expand Down Expand Up @@ -81,6 +81,17 @@ const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.Concretize
_unwrap_concretized_slice(cs::AbstractPPL.ConcretizedSlice) = cs.range
_unwrap_concretized_slice(x::Union{Integer,AbstractUnitRange,Colon}) = x

"""
vnt_size(x)

Get the size of an object `x` for use in `VarNamedTuple` and `PartialArray`.

By default, this falls back onto `Base.size`, but can be overloaded for custom types.
This notion of type is used to determine whether a value can be set into a `PartialArray`
as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details.
"""
vnt_size(x) = size(x)

"""
ArrayLikeBlock{T,I}

Expand Down Expand Up @@ -156,11 +167,12 @@ Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known elem

One can set values in a `PartialArray` either element-by-element, or with ranges like
`arr[1:3,2] = [5,10,15]`. When setting values over a range of indices, the value being set
must either be an `AbstractArray` or otherwise something for which `size(value)` is defined,
and the size mathces the range. If the value is an `AbstractArray`, the elements are copied
individually, but if it is not, the value is stored as a block, that takes up the whole
range, e.g. `[1:3,2]`, but is only a single object. Getting such a block-value must be done
with the exact same range of indices, otherwise an error is thrown.
must either be an `AbstractArray` or otherwise something for which `vnt_size(value)` or
`Base.size(value)` (which `vnt_size` falls back onto) is defined, and the size matches the
range. If the value is an `AbstractArray`, the elements are copied individually, but if it
is not, the value is stored as a block, that takes up the whole range, e.g. `[1:3,2]`, but
is only a single object. Getting such a block-value must be done with the exact same range
of indices, otherwise an error is thrown.

If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check
if, after the new value has been set, the element type can be made more concrete. If so,
Expand Down Expand Up @@ -596,7 +608,7 @@ The value only depends on the types of the arguments, and should be constant pro
function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES})
return _is_multiindex(inds) &&
!isa(value, AbstractArray) &&
hasmethod(size, Tuple{typeof(value)})
hasmethod(vnt_size, Tuple{typeof(value)})
end

function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
Expand All @@ -612,11 +624,11 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
new_data = pa.data
if _needs_arraylikeblock(value, inds...)
inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds))
if size(value) != inds_size
if vnt_size(value) != inds_size
throw(
DimensionMismatch(
"Assigned value has size $(size(value)), which does not match the " *
"size implied by the indices $(map(x -> _length_needed(x), inds)).",
"Assigned value has size $(vnt_size(value)), which does not match " *
"the size implied by the indices $(map(x -> _length_needed(x), inds)).",
),
)
end
Expand Down
Loading