Skip to content
Closed
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand All @@ -19,5 +20,6 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9"
ChainRulesCore = "0.9.7, 0.10"
Distributions = "0.23.8, 0.24, 0.25"
MacroTools = "0.5.6"
Symbolics = "1"
ZygoteRules = "0.2"
julia = "1.3"
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ include("compiler.jl")
include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")
include("symbolic/Symbolic.jl")
include("submodel_macro.jl")

end # module
77 changes: 77 additions & 0 deletions src/symbolic/Symbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
module Symbolic

import ..DynamicPPL
import ..DynamicPPL: Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext

import Random
import Bijectors
using Distributions
import Symbolics
import Symbolics: SymbolicUtils

issym(x::Union{Symbolics.Num, SymbolicUtils.Symbolic}) = true
issym(x) = false

include("rules.jl")
include("contexts.jl")

symbolize(args...; kwargs...) = symbolize(Random.GLOBAL_RNG, args...; kwargs...)
function symbolize(
rng::Random.AbstractRNG,
m::Model,
vi::VarInfo = VarInfo(m);
spl = SampleFromPrior(),
ctx = DefaultContext(),
include_data = false
)
m(rng, vi, spl, ctx);
θ_orig = vi[spl]

# Symbolic `logpdf` for fixed observations.
Symbolics.@variables θ[1:length(θ_orig)]
vi = VarInfo(vi, spl, θ, zero(eltype(θ)));
m(rng, vi, spl, ctx);

return vi, θ
end

function dependencies(ctx::SymbolicContext, vn::VarName)
right = ctx.vn2rights[vn]
r = Symbolics.value(right)

if !issym(r)
# No dependencies.
return []
end

args = SymbolicUtils.arguments(r)
return mapreduce(vcat, args) do a
Symbolics.get_variables(a)
end
end
function dependencies(ctx::SymbolicContext, symbolic = false)
vn2var = ctx.vn2var
var2vn = Dict(values(vn2var) .=> keys(vn2var))
return Dict(
(symbolic ? vn2var[vn] : vn) => map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn))
for vn in keys(ctx.vn2var)
)
end

function dependencies(m::Model, symbolic = false)
ctx = SymbolicContext(DefaultContext())
vi = symbolize(m, VarInfo(m), ctx = ctx)

return dependencies(ctx, symbolic)
end


function symbolic_logp(m::Model)
vi, θ = symbolize(m)
lp = DynamicPPL.getlogp(vi)
lp_analytic = analytic_rw(Symbolics.value(lp))
lp_analytic_num = addnum_rw(lp_analytic)

return lp_analytic_num, θ
end
end
29 changes: 29 additions & 0 deletions src/symbolic/contexts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
struct SymbolicContext{Ctx} <: DynamicPPL.AbstractContext
ctx::Ctx
vn2var::Dict
vn2rights::Dict
end
SymbolicContext() = SymbolicContext(DefaultContext())
SymbolicContext(ctx) = SymbolicContext(ctx, Dict(), Dict())

# assume
function DynamicPPL.tilde(rng, ctx::SymbolicContext, sampler, right, vn::VarName, inds, vi)
if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn]))
# Distribution is symbolic OR variable is.
ctx.vn2var[vn] = vi[vn]
ctx.vn2rights[vn] = right
end

return DynamicPPL.tilde(rng, ctx.ctx, sampler, right, vn, inds, vi)
end


# TODO: Make it more useful when working with symbolic observations.
# observe
function DynamicPPL.tilde(ctx::SymbolicContext, sampler, right, left, vi)
if Symbolic.issym(right) || Symbolic.issym(left)
# TODO: implement
end

return DynamicPPL.tilde(ctx.ctx, sampler, right, left, vi)
end
72 changes: 72 additions & 0 deletions src/symbolic/rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import Bijectors
import Symbolics
using Symbolics.SymbolicUtils

Symbolics.@register Bijectors.logpdf_with_trans(dist, r, istrans)

# Some predicates
isdist(d) = (d isa Type) && (d <: Distribution)
islogpdf(f::Function) = f === Distributions.logpdf
islogpdf(x) = false

# HACK: Apparently this is needed for disambiguiation.
# TODO: Open issue.
Symbolics.:<ₑ(a::Real, b::Symbolics.Num) = Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b))
Symbolics.:<ₑ(a::Symbolics.Num, b::Real) = Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b))

#############
### Rules ###
#############
# HACK: We'll wrap rewriters to add back `Num`. This way we can get jacobians and whatnot at then end.
const rmnum_rule = @rule (~x) => Symbolics.value(~x)
const addnum_rule = @rule (~x) => Symbolics.Num(~x)

# In the case where we want to work directly with the `x ~ Distribution` statements, the following rules can be useful:
const logpdf_rule = @rule (~x ~ ~d) => Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x));
const rand_rule = @rule (~x ~ ~d) => Distributions.rand(Symbolics.Num(~d))

# We don't want to trace into `Bijectors.logpdf_with_trans`, so we just replace it with `logpdf`.
islogpdf_with_trans(f::Function) = f === Bijectors.logpdf_with_trans
islogpdf_with_trans(x) = false
const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) => logpdf(~dist, ~x)

# Attempt to expand `logpdf` to get analytical expressions.
# The idea is that `getlogpdf(d, args)` should return a method of the following signature:
#
# f(args..., x)
#
# which returns the logpdf.
# HACK: this is very hacky but you get the idea
import Distributions: StatsFuns
function getlogpdf(d, args)
replacements = Dict(
:Normal => StatsFuns.normlogpdf,
:Gamma => StatsFuns.gammalogpdf
)

dsym = Symbol(d)
if haskey(replacements, dsym)
return replacements[dsym]
else
return d
end
end

const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x))


#################
### Rewriters ###
#################
# TODO: these should probably be instantiated when needed, rather than here.
const analytic_rw = Rewriters.Postwalk(
Rewriters.Chain((
rmnum_rule, # 0. Remove `Num` so we're only working stuff from `SymbolicUtils.jl`.
logpdf_with_trans_rule, # 1. Replace `logpdf_with_trans` with `logpdf`.
analytic_rule, # 2. Attempt to replace `logpdf` with analytic expression.
))
)

# So we add back `Num` to all terms to allow differentiation.
const rmnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(rmnum_rule))
const addnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(addnum_rule))
5 changes: 5 additions & 0 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
)
end

function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector, lp::T) where {T}
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi)))
end

function VarInfo(
rng::Random.AbstractRNG,
model::Model,
Expand Down