diff --git a/docs/src/api.md b/docs/src/api.md index 20eb1ce35..f687fd90a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -377,6 +377,7 @@ SimpleVarInfo ```@docs DynamicPPL.VarNamedTuples.VarNamedTuple +DynamicPPL.VarNamedTuples.vnt_size ``` ### Accumulators diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index aa08c119d..daa062d2d 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -168,6 +168,10 @@ For instance, if `setindex!!(vnt, @varname(a[1:5]), val)` has been set, then the Not `@varname(a[1:10])`, nor `@varname(a[3])`, nor for anything else that overlaps with `@varname(a[1:5])`. `haskey` likewise only returns true for `@varname(a[1:5])`, and `keys(vnt)` only has that as an element. +The size of a value, for the purposes of inserting it into a `PartialArray`, is determined by a call to `vnt_size`. +`vnt_size` falls back to calling `Base.size`. +The reason we define a distinct function is to be able to control its behaviour, if necessary, without type piracy. + ## Limitations This design has a several of benefits, for performance and generality, but it also has limitations: diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 25ca59018..95831062f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -185,7 +185,7 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("varnamedtuple.jl") -using .VarNamedTuples: VarNamedTuple +using .VarNamedTuples: VarNamedTuples, VarNamedTuple include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") diff --git a/src/contexts/init.jl b/src/contexts/init.jl index dd9e99421..f5259f1cd 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -215,7 +215,7 @@ struct RangeAndLinked{T<:Tuple} original_size::T end -Base.size(ral::RangeAndLinked) = ral.original_size +VarNamedTuples.vnt_size(ral::RangeAndLinked) = ral.original_size """ VectorWithRanges{Tlink}( diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 0346ec6e6..4e99aa1ec 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,7 +8,7 @@ using BangBang using Accessors using ..DynamicPPL: _compose_no_identity -export VarNamedTuple +export VarNamedTuple, vnt_size # We define our own getindex, setindex!!, and haskey functions, which we use to # get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be @@ -81,6 +81,17 @@ const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.Concretize _unwrap_concretized_slice(cs::AbstractPPL.ConcretizedSlice) = cs.range _unwrap_concretized_slice(x::Union{Integer,AbstractUnitRange,Colon}) = x +""" + vnt_size(x) + +Get the size of an object `x` for use in `VarNamedTuple` and `PartialArray`. + +By default, this falls back onto `Base.size`, but can be overloaded for custom types. +This notion of type is used to determine whether a value can be set into a `PartialArray` +as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details. +""" +vnt_size(x) = size(x) + """ ArrayLikeBlock{T,I} @@ -156,11 +167,12 @@ Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known elem One can set values in a `PartialArray` either element-by-element, or with ranges like `arr[1:3,2] = [5,10,15]`. When setting values over a range of indices, the value being set -must either be an `AbstractArray` or otherwise something for which `size(value)` is defined, -and the size mathces the range. If the value is an `AbstractArray`, the elements are copied -individually, but if it is not, the value is stored as a block, that takes up the whole -range, e.g. `[1:3,2]`, but is only a single object. Getting such a block-value must be done -with the exact same range of indices, otherwise an error is thrown. +must either be an `AbstractArray` or otherwise something for which `vnt_size(value)` or +`Base.size(value)` (which `vnt_size` falls back onto) is defined, and the size matches the +range. If the value is an `AbstractArray`, the elements are copied individually, but if it +is not, the value is stored as a block, that takes up the whole range, e.g. `[1:3,2]`, but +is only a single object. Getting such a block-value must be done with the exact same range +of indices, otherwise an error is thrown. If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check if, after the new value has been set, the element type can be made more concrete. If so, @@ -596,7 +608,7 @@ The value only depends on the types of the arguments, and should be constant pro function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) return _is_multiindex(inds) && !isa(value, AbstractArray) && - hasmethod(size, Tuple{typeof(value)}) + hasmethod(vnt_size, Tuple{typeof(value)}) end function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) @@ -612,11 +624,11 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) new_data = pa.data if _needs_arraylikeblock(value, inds...) inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) - if size(value) != inds_size + if vnt_size(value) != inds_size throw( DimensionMismatch( - "Assigned value has size $(size(value)), which does not match the " * - "size implied by the indices $(map(x -> _length_needed(x), inds)).", + "Assigned value has size $(vnt_size(value)), which does not match " * + "the size implied by the indices $(map(x -> _length_needed(x), inds)).", ), ) end