Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
208b62c
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jun 30, 2021
a8e55bd
updated SimpleVarInfo impl
torfjelde Jun 30, 2021
8ea80d7
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jun 30, 2021
d317bd8
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jun 30, 2021
a88f8ea
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jun 30, 2021
bfd7c78
added eltype impl for SimpleVarInfo
torfjelde Jul 2, 2021
acb15eb
formatting
torfjelde Jul 2, 2021
4828aab
fixed eltype for SimpleVarInfo
torfjelde Jul 6, 2021
b56024e
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 9, 2021
e4f0ad2
formatting
torfjelde Jul 9, 2021
ccfd112
initial work on allowing sampling using SimpleVarInfo
torfjelde Jul 9, 2021
d660433
formatting
torfjelde Jul 9, 2021
c925b07
Merge branch 'master' into tor/simple-varinfo-v2
torfjelde Jul 16, 2021
3ec72c6
Merge branch 'tor/simple-varinfo-v2' of github.com:TuringLang/Dynamic…
torfjelde Jul 16, 2021
90cf754
add constructor for SimpleVarInfo using model
torfjelde Jul 16, 2021
0ab9d8b
improved leftover to_namedtuple_expr, fixing a bug when used with Zygote
torfjelde Jul 16, 2021
42ad552
bumped patch version
torfjelde Jul 16, 2021
975184d
Merge branch 'tor/allargs-construction-improvement' into tor/simple-v…
torfjelde Jul 16, 2021
a0cd0c4
Merge branch 'master' into tor/simple-varinfo-v2
torfjelde Jul 19, 2021
744a032
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 19, 2021
76daca6
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 20, 2021
6f947f7
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 22, 2021
4076f63
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 23, 2021
d0a08f6
fixed some issues and added support for usage of Dict in SimpleVarInfo
torfjelde Aug 5, 2021
4002318
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Aug 5, 2021
ff75ddc
added docstring and improved indexing behvaior for SimpleVarInfo
torfjelde Aug 5, 2021
d29dd8f
formatting
torfjelde Aug 5, 2021
a72594f
dont allow sampling with indexing when using SimpleVarInfo with Named…
torfjelde Aug 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
TypedVarInfo,
SimpleVarInfo,
push!!,
empty!!,
getlogp,
Expand Down Expand Up @@ -135,6 +136,7 @@ include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("threadsafe.jl")
include("context_implementations.jl")
include("compiler.jl")
Expand Down
182 changes: 182 additions & 0 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
using Setfield

"""
SimpleVarInfo{NT,T} <: AbstractVarInfo

A simple wrapper of the parameters with a `logp` field for
accumulation of the logdensity.

Currently only implemented for `NT <: NamedTuple`.

## Notes
The major differences between this and `TypedVarInfo` are:
1. `SimpleVarInfo` does not require linearization.
2. `SimpleVarInfo` can use more efficient bijectors.
3. `SimpleVarInfo` only supports evaluation.
"""
struct SimpleVarInfo{NT,T} <: AbstractVarInfo
θ::NT
logp::T
end

SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T))
SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ)
SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(nothing)
SimpleVarInfo() = SimpleVarInfo{Float64}()

getlogp(vi::SimpleVarInfo) = vi.logp
setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp)
acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp)

function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
vi.logp[] = logp
return vi
end

function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
vi.logp[] += logp
return vi
end

function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym}
# Use `getproperty` instead of `getfield`
value = getproperty(nt, sym)
return _getindex(value, inds)
end

function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym}
return _getvalue(vi.θ, Val{sym}(), vn.indexing)
end
# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than
# just `Vector`.
getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns)
# To disambiguiate.
getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns)

haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn))

istrans(::SimpleVarInfo, vn::VarName) = false

getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ
getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ
# TODO: Should we do better?
getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ
getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn)
getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns)
# HACK: Need to disambiguiate.
getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns)

# Necessary for `matchingvalue` to work properly.
function Base.eltype(
vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior}
) where {T}
return T
end

function push!!(
vi::SimpleVarInfo{Nothing}, vn::VarName{sym,Tuple{}}, value, dist::Distribution
) where {sym}
@set vi.θ = NamedTuple{(sym,)}((value,))
end
function push!!(
vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution
) where {sym}
@set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,)))
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any good ideas of how we can support cases where we actually have indexing, i.e. VarName{sym} rather than VarName{sym, Tuple{}}? @devmotion ? 👼

Might be useful to exploit parent if we have #272 ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess everyone has already had their own ideas about "trie-style" VarInfo, but mine was the following (similar to your views idea, I think): store the complete data in one big array (or perhaps one per type). Then put a trie-like thing on top of that, referencing only views into it. So we'd have something like, conceptually,

backup = [1.0, 2.0, 3.0, 4.0]
vi = (x = @view(backup, 1:2), 
      y = (var"1" = @view(backup, 3), 
           var"3" = @view(backup, 4)))

Of course there's a lot of possibilities of implementing this better (like a dictionary for y, in case there's a lot of indices.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that var"1" won't be a good idea once the indices are dynamic (this I guess is what you're referring to when saying that it can be implemented using a dictionary).

But for SimpleVarInfo, I'm happy to only support "simple" expressions meaning expressions of the form x and x[...], not x[1][...], or equivalently VarName{<:Any,Tuple{}} and VarName{<:Any, Tuple{<:Tuple}}.

Now that everything is a View, this could be implemented by just preallocating equivalent of parent and the inserting into this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another thing: nested indexing is kind of "useless" if we have something like #275 😕


# Context implementations
function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo)
value, logp, vi_new = tilde_assume(context, right, vn, inds, vi)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notice how here we also take the vi_new.

Ideally this is what we should be doing overall if we're going to support immutable varinfos. But if we force this, we'll end up breaking a lot of downstream samplers since it requires changing assume statements to also return vi.

As an intermediate step, it might be worth just overloading tilde_assume!! as I have done here + assume as I have done below. This also brings up another annoyance though: vi should really be at the beginning of the arguments of all the tilde-statements (and ideally assume too, but we can delay this) to minimize method ambiguities but allow us to have some special behavior for the different impls for AbstractVarInfo.

return value, acclogp!!(vi_new, logp)
end

function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo)
left = vi[vn]
return left, Distributions.loglikelihood(dist, left), vi
end

function assume(
rng::Random.AbstractRNG,
sampler::SampleFromPrior,
dist::Distribution,
vn::VarName,
vi::SimpleVarInfo,
)
value = init(rng, dist, sampler)
vi = push!!(vi, vn, value, dist, sampler)
vi = settrans!!(vi, false, vn)
return value, Distributions.loglikelihood(dist, value), vi
end

# function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo)
# throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi)))
# end

function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo)
value, logp, vi_new = dot_tilde_assume(context, right, left, vn, inds, vi)
# Mutation of `value` no longer occurs in main body, so we do it here.
left .= value
return value, acclogp!!(vi_new, logp)
end

function dot_assume(
dist::MultivariateDistribution,
var::AbstractMatrix,
vns::AbstractVector{<:VarName},
vi::SimpleVarInfo,
)
@assert length(dist) == size(var, 1)
# NOTE: We cannot work with `var` here because we might have a model of the form
#
# m = Vector{Float64}(undef, n)
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
value = vi[vns]
lp = sum(zip(vns, eachcol(value))) do vn, val
return Distributions.logpdf(dist, val)
end
return value, lp, vi
end

function dot_assume(
dists::Union{Distribution,AbstractArray{<:Distribution}},
var::AbstractArray,
vns::AbstractArray{<:VarName},
vi::SimpleVarInfo{<:NamedTuple},
)
# NOTE: We cannot work with `var` here because we might have a model of the form
#
# m = Vector{Float64}(undef, n)
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
value = vi[vns]
lp = sum(Distributions.logpdf.(dists, value))
return value, lp, vi
end

# HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals.
increment_num_produce!(::SimpleVarInfo) = nothing
settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi

# Interaction with `VarInfo`
SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi)
function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names}
vals = map(names) do n
let md = getfield(vi.metadata, n)
x = map(enumerate(md.ranges)) do (i, r)
reconstruct(md.dists[i], md.vals[r])
end

# TODO: Doesn't support batches of `MultivariateDistribution`?
length(x) == 1 ? x[1] : x
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where the indexing goes wrong. Constructing a SimpleVarInfo from TypedVarInfo is somewhat non-trivial if we allow indexing. E.g.

julia> @model function demo3(m)
           m[:, 1] ~ MvNormal(size(m, 1), 1.0)
           m[:, 2] ~ MvNormal(size(m, 1), 1.0)
           return m
       end
demo3 (generic function with 1 method)

julia> setmissings(m::Model, missings...) = Model{missings}(m.name, m.f, m.args, m.defaults);

julia> m3 = setmissings(demo3(rand(3, 2)), :m);

julia> m3()
3×2 Matrix{Float64}:
  0.264148  1.30146
 -0.974804  0.0924204
  1.44424   1.3519

julia> vi = VarInfo(m3);

julia> svi = SimpleVarInfo(vi);

julia> svi.θ.m
2-element Vector{Vector{Float64}}:
 [-0.8340592002751136, 0.5670953725697917, -0.5460275730331128]
 [0.12464384286023088, -1.2118644862064083, 2.3386842884350765]

julia> svi[@varname(m[:, 1])] # (×) since `svi.θ` is vec of vecs, the indexing produces the wrong result
2-element Vector{Vector{Float64}}:
 [-0.8340592002751136, 0.5670953725697917, -0.5460275730331128]
 [0.12464384286023088, -1.2118644862064083, 2.3386842884350765]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we de-linearise the matrix variable m while converting vi::VarInfo to svi::SimpleVarInfo? That is, we store m::Matrix with its original shape in SimpleVarInfo. This de-linearise operation is also needed in MCMCChains when we group matrix variables, I remember.

Copy link
Member Author

@torfjelde torfjelde Jun 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Me knee-jerk reaction is that ordering will be an issue. This is not an issue for VarInfo because there you'll correctly associate vals[1] with vns[1]; it doesn't matter whether vns[1] is actually m[:, 2]. MCMCChains similarly doesn't care about this. In contrast, in SimpleVarInfo we do if we want to ensure that indexing works 😕

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There might be a better way of doing the indexing though! Not ruling out a solution yet; just saying that it's not that easy I think.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g. if we start using view (see #272 ) we could potentially support more scenarios by allocating parent :)

end
end

return SimpleVarInfo{T}(NamedTuple{names}(vals))
end

function SimpleVarInfo(model::Model, args...)
return SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...))
end