Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.13.0"
version = "0.13.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
31 changes: 30 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,35 @@ left-hand side of a `.~` expression such as `x .~ Normal()`.

This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
variables.

# Example
```jldoctest; setup=:(using Distributions)
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); vns
2-element Vector{VarName{:x, Tuple{Tuple{Colon, Int64}}}}:
x[:,1]
x[:,2]

julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns
1×2 Matrix{VarName{:x, Tuple{Tuple{Colon}, Tuple{Int64, Int64}}}}:
x[:][1,1] x[:][1,2]

julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns
3-element Vector{VarName{:x, Tuple{Tuple{Int64}, Tuple{Int64}}}}:
x[1][1]
x[1][2]
x[1][3]

julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2, 3), @varname(x)); vns
1×2×3 Array{VarName{:x, Tuple{Tuple{Int64, Int64, Int64}}}, 3}:
[:, :, 1] =
x[1,1,1] x[1,2,1]

[:, :, 2] =
x[1,1,2] x[1,2,2]

[:, :, 3] =
x[1,1,3] x[1,2,3]
```
"""
unwrap_right_left_vns(right, left, vns) = right, left, vns
function unwrap_right_left_vns(right::NamedDist, left, vns)
Expand All @@ -103,7 +132,7 @@ function unwrap_right_left_vns(
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
# and we therefore add the `Colon()` below.
vns = map(axes(left, 2)) do i
return VarName(vn, (vn.indexing..., Colon(), Tuple(i)))
return VarName(vn, (vn.indexing..., (Colon(), i)))
end
return unwrap_right_left_vns(right, left, vns)
end
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0"
AbstractPPL = "0.1.4, 0.2"
AbstractPPL = "0.2"
Bijectors = "0.9.5"
Distributions = "< 0.25.11"
DistributionsAD = "0.6.3"
Expand Down