-
Notifications
You must be signed in to change notification settings - Fork 37
Faster evaluation: SimpleVarInfo
#267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
208b62c
a8e55bd
8ea80d7
d317bd8
a88f8ea
bfd7c78
acb15eb
4828aab
b56024e
e4f0ad2
ccfd112
d660433
c925b07
3ec72c6
90cf754
0ab9d8b
42ad552
975184d
a0cd0c4
744a032
76daca6
6f947f7
4076f63
d0a08f6
4002318
ff75ddc
d29dd8f
a72594f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,182 @@ | ||
| using Setfield | ||
|
|
||
| """ | ||
| SimpleVarInfo{NT,T} <: AbstractVarInfo | ||
|
|
||
| A simple wrapper of the parameters with a `logp` field for | ||
| accumulation of the logdensity. | ||
|
|
||
| Currently only implemented for `NT <: NamedTuple`. | ||
|
|
||
| ## Notes | ||
| The major differences between this and `TypedVarInfo` are: | ||
| 1. `SimpleVarInfo` does not require linearization. | ||
| 2. `SimpleVarInfo` can use more efficient bijectors. | ||
| 3. `SimpleVarInfo` only supports evaluation. | ||
| """ | ||
| struct SimpleVarInfo{NT,T} <: AbstractVarInfo | ||
| θ::NT | ||
| logp::T | ||
| end | ||
|
|
||
| SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) | ||
| SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) | ||
| SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(nothing) | ||
| SimpleVarInfo() = SimpleVarInfo{Float64}() | ||
|
|
||
| getlogp(vi::SimpleVarInfo) = vi.logp | ||
| setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) | ||
| acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) | ||
|
|
||
| function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) | ||
| vi.logp[] = logp | ||
| return vi | ||
| end | ||
|
|
||
| function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) | ||
| vi.logp[] += logp | ||
| return vi | ||
| end | ||
|
|
||
| function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} | ||
| # Use `getproperty` instead of `getfield` | ||
| value = getproperty(nt, sym) | ||
| return _getindex(value, inds) | ||
| end | ||
|
|
||
| function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} | ||
| return _getvalue(vi.θ, Val{sym}(), vn.indexing) | ||
| end | ||
| # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than | ||
| # just `Vector`. | ||
| getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) | ||
| # To disambiguiate. | ||
| getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns) | ||
|
|
||
| haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn)) | ||
|
|
||
| istrans(::SimpleVarInfo, vn::VarName) = false | ||
|
|
||
| getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ | ||
| getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ | ||
| # TODO: Should we do better? | ||
| getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ | ||
| getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn) | ||
| getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) | ||
| # HACK: Need to disambiguiate. | ||
| getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) | ||
|
|
||
| # Necessary for `matchingvalue` to work properly. | ||
| function Base.eltype( | ||
| vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} | ||
| ) where {T} | ||
| return T | ||
| end | ||
|
|
||
| function push!!( | ||
| vi::SimpleVarInfo{Nothing}, vn::VarName{sym,Tuple{}}, value, dist::Distribution | ||
| ) where {sym} | ||
| @set vi.θ = NamedTuple{(sym,)}((value,)) | ||
| end | ||
| function push!!( | ||
| vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution | ||
| ) where {sym} | ||
| @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) | ||
| end | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any good ideas of how we can support cases where we actually have indexing, i.e. Might be useful to exploit
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess everyone has already had their own ideas about "trie-style" VarInfo, but mine was the following (similar to your views idea, I think): store the complete data in one big array (or perhaps one per type). Then put a trie-like thing on top of that, referencing only views into it. So we'd have something like, conceptually, backup = [1.0, 2.0, 3.0, 4.0]
vi = (x = @view(backup, 1:2),
y = (var"1" = @view(backup, 3),
var"3" = @view(backup, 4)))Of course there's a lot of possibilities of implementing this better (like a dictionary for
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue is that But for Now that everything is a
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another thing: nested indexing is kind of "useless" if we have something like #275 😕 |
||
|
|
||
| # Context implementations | ||
| function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo) | ||
| value, logp, vi_new = tilde_assume(context, right, vn, inds, vi) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Notice how here we also take the Ideally this is what we should be doing overall if we're going to support immutable varinfos. But if we force this, we'll end up breaking a lot of downstream samplers since it requires changing As an intermediate step, it might be worth just overloading |
||
| return value, acclogp!!(vi_new, logp) | ||
| end | ||
|
|
||
| function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo) | ||
| left = vi[vn] | ||
| return left, Distributions.loglikelihood(dist, left), vi | ||
| end | ||
|
|
||
| function assume( | ||
| rng::Random.AbstractRNG, | ||
| sampler::SampleFromPrior, | ||
| dist::Distribution, | ||
| vn::VarName, | ||
| vi::SimpleVarInfo, | ||
| ) | ||
| value = init(rng, dist, sampler) | ||
| vi = push!!(vi, vn, value, dist, sampler) | ||
| vi = settrans!!(vi, false, vn) | ||
| return value, Distributions.loglikelihood(dist, value), vi | ||
| end | ||
|
|
||
| # function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) | ||
| # throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi))) | ||
| # end | ||
|
|
||
| function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) | ||
| value, logp, vi_new = dot_tilde_assume(context, right, left, vn, inds, vi) | ||
| # Mutation of `value` no longer occurs in main body, so we do it here. | ||
| left .= value | ||
| return value, acclogp!!(vi_new, logp) | ||
| end | ||
|
|
||
| function dot_assume( | ||
| dist::MultivariateDistribution, | ||
| var::AbstractMatrix, | ||
| vns::AbstractVector{<:VarName}, | ||
| vi::SimpleVarInfo, | ||
| ) | ||
| @assert length(dist) == size(var, 1) | ||
| # NOTE: We cannot work with `var` here because we might have a model of the form | ||
| # | ||
| # m = Vector{Float64}(undef, n) | ||
| # m .~ Normal() | ||
| # | ||
| # in which case `var` will have `undef` elements, even if `m` is present in `vi`. | ||
| value = vi[vns] | ||
| lp = sum(zip(vns, eachcol(value))) do vn, val | ||
| return Distributions.logpdf(dist, val) | ||
| end | ||
| return value, lp, vi | ||
| end | ||
|
|
||
| function dot_assume( | ||
| dists::Union{Distribution,AbstractArray{<:Distribution}}, | ||
| var::AbstractArray, | ||
| vns::AbstractArray{<:VarName}, | ||
| vi::SimpleVarInfo{<:NamedTuple}, | ||
| ) | ||
| # NOTE: We cannot work with `var` here because we might have a model of the form | ||
| # | ||
| # m = Vector{Float64}(undef, n) | ||
| # m .~ Normal() | ||
| # | ||
| # in which case `var` will have `undef` elements, even if `m` is present in `vi`. | ||
| value = vi[vns] | ||
| lp = sum(Distributions.logpdf.(dists, value)) | ||
| return value, lp, vi | ||
| end | ||
|
|
||
| # HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. | ||
| increment_num_produce!(::SimpleVarInfo) = nothing | ||
| settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi | ||
|
|
||
| # Interaction with `VarInfo` | ||
| SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) | ||
| function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} | ||
| vals = map(names) do n | ||
| let md = getfield(vi.metadata, n) | ||
| x = map(enumerate(md.ranges)) do (i, r) | ||
| reconstruct(md.dists[i], md.vals[r]) | ||
| end | ||
|
|
||
| # TODO: Doesn't support batches of `MultivariateDistribution`? | ||
| length(x) == 1 ? x[1] : x | ||
|
||
| end | ||
| end | ||
|
|
||
| return SimpleVarInfo{T}(NamedTuple{names}(vals)) | ||
| end | ||
|
|
||
| function SimpleVarInfo(model::Model, args...) | ||
| return SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) | ||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.