Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ SimpleVarInfo

```@docs
DynamicPPL.VarNamedTuples.VarNamedTuple
DynamicPPL.VarNamedTuples.vnt_size
```

### Accumulators
Expand Down
4 changes: 4 additions & 0 deletions docs/src/internals/varnamedtuple.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ For instance, if `setindex!!(vnt, @varname(a[1:5]), val)` has been set, then the
Not `@varname(a[1:10])`, nor `@varname(a[3])`, nor for anything else that overlaps with `@varname(a[1:5])`.
`haskey` likewise only returns true for `@varname(a[1:5])`, and `keys(vnt)` only has that as an element.

The size of a value, for the purposes of inserting it into a `PartialArray`, is determined by a call to `vnt_size`.
`vnt_size` falls back to calling `Base.size`.
The reason we define a distinct function is to be able to control its behaviour, if necessary, without type piracy.

## Limitations

This design has a several of benefits, for performance and generality, but it also has limitations:
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
# Necessary forward declarations
include("utils.jl")
include("varnamedtuple.jl")
using .VarNamedTuples: VarNamedTuple
using .VarNamedTuples: VarNamedTuples, VarNamedTuple
include("contexts.jl")
include("contexts/default.jl")
include("contexts/init.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ struct RangeAndLinked{T<:Tuple}
original_size::T
end

Base.size(ral::RangeAndLinked) = ral.original_size
VarNamedTuples.vnt_size(ral::RangeAndLinked) = ral.original_size

"""
VectorWithRanges{Tlink}(
Expand Down
32 changes: 22 additions & 10 deletions src/varnamedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using BangBang
using Accessors
using ..DynamicPPL: _compose_no_identity

export VarNamedTuple
export VarNamedTuple, vnt_size

# We define our own getindex, setindex!!, and haskey functions, which we use to
# get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be
Expand Down Expand Up @@ -81,6 +81,17 @@ const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.Concretize
_unwrap_concretized_slice(cs::AbstractPPL.ConcretizedSlice) = cs.range
_unwrap_concretized_slice(x::Union{Integer,AbstractUnitRange,Colon}) = x

"""
vnt_size(x)

Get the size of an object `x` for use in `VarNamedTuple` and `PartialArray`.

By default, this falls back onto `Base.size`, but can be overloaded for custom types.
This notion of type is used to determine whether a value can be set into a `PartialArray`
as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details.
"""
vnt_size(x) = size(x)

"""
ArrayLikeBlock{T,I}

Expand Down Expand Up @@ -156,11 +167,12 @@ Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known elem

One can set values in a `PartialArray` either element-by-element, or with ranges like
`arr[1:3,2] = [5,10,15]`. When setting values over a range of indices, the value being set
must either be an `AbstractArray` or otherwise something for which `size(value)` is defined,
and the size mathces the range. If the value is an `AbstractArray`, the elements are copied
individually, but if it is not, the value is stored as a block, that takes up the whole
range, e.g. `[1:3,2]`, but is only a single object. Getting such a block-value must be done
with the exact same range of indices, otherwise an error is thrown.
must either be an `AbstractArray` or otherwise something for which `vnt_size(value)` or
`Base.size(value)` (which `vnt_size` falls back onto) is defined, and the size matches the
range. If the value is an `AbstractArray`, the elements are copied individually, but if it
is not, the value is stored as a block, that takes up the whole range, e.g. `[1:3,2]`, but
is only a single object. Getting such a block-value must be done with the exact same range
of indices, otherwise an error is thrown.

If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check
if, after the new value has been set, the element type can be made more concrete. If so,
Expand Down Expand Up @@ -596,7 +608,7 @@ The value only depends on the types of the arguments, and should be constant pro
function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES})
return _is_multiindex(inds) &&
!isa(value, AbstractArray) &&
hasmethod(size, Tuple{typeof(value)})
hasmethod(vnt_size, Tuple{typeof(value)})
end

function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
Expand All @@ -612,11 +624,11 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
new_data = pa.data
if _needs_arraylikeblock(value, inds...)
inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds))
if size(value) != inds_size
if vnt_size(value) != inds_size
throw(
DimensionMismatch(
"Assigned value has size $(size(value)), which does not match the " *
"size implied by the indices $(map(x -> _length_needed(x), inds)).",
"Assigned value has size $(vnt_size(value)), which does not match " *
"the size implied by the indices $(map(x -> _length_needed(x), inds)).",
),
)
end
Expand Down