Skip to content

Commit 23141be

Browse files
committed
formatting
1 parent 70c9997 commit 23141be

4 files changed

Lines changed: 39 additions & 35 deletions

File tree

src/symbolic/Symbolic.jl

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
module Symbolic
22

3-
import ..DynamicPPL
4-
import ..DynamicPPL: Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext
3+
using ..DynamicPPL: DynamicPPL
4+
import ..DynamicPPL:
5+
Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext
56

6-
import Random
7-
import Bijectors
7+
using Random: Random
8+
using Bijectors: Bijectors
89
using Distributions
9-
import Symbolics
10+
using Symbolics: Symbolics
1011
import Symbolics: SymbolicUtils
1112

12-
issym(x::Union{Symbolics.Num, SymbolicUtils.Symbolic}) = true
13+
issym(x::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = true
1314
issym(x) = false
1415

1516
include("rules.jl")
@@ -19,18 +20,19 @@ symbolize(args...; kwargs...) = symbolize(Random.GLOBAL_RNG, args...; kwargs...)
1920
function symbolize(
2021
rng::Random.AbstractRNG,
2122
m::Model,
22-
vi::VarInfo = VarInfo(m);
23-
spl = SampleFromPrior(),
24-
ctx = DefaultContext(),
25-
include_data = false
23+
vi::VarInfo=VarInfo(m);
24+
spl=SampleFromPrior(),
25+
ctx=DefaultContext(),
26+
include_data=false,
2627
)
27-
m(rng, vi, spl, ctx);
28+
m(rng, vi, spl, ctx)
2829
θ_orig = vi[spl]
2930

3031
# Symbolic `logpdf` for fixed observations.
32+
# TODO: don't `collect` once symbolic arrays are mature enough.
3133
Symbolics.@variables θ[1:length(θ_orig)]
32-
vi = VarInfo{Real}(vi, spl, θ, 0.0);
33-
m(vi, ctx);
34+
vi = VarInfo{Real}(vi, spl, θ, 0.0)
35+
m(vi, ctx)
3436

3537
return vi, θ
3638
end
@@ -49,23 +51,23 @@ function dependencies(ctx::SymbolicContext, vn::VarName)
4951
Symbolics.get_variables(a)
5052
end
5153
end
52-
function dependencies(ctx::SymbolicContext, symbolic = false)
54+
function dependencies(ctx::SymbolicContext, symbolic=false)
5355
vn2var = ctx.vn2var
5456
var2vn = Dict(values(vn2var) .=> keys(vn2var))
5557
return Dict(
56-
(symbolic ? vn2var[vn] : vn) => map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn))
57-
for vn in keys(ctx.vn2var)
58+
(symbolic ? vn2var[vn] : vn) =>
59+
map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) for
60+
vn in keys(ctx.vn2var)
5861
)
5962
end
6063

61-
function dependencies(m::Model, symbolic = false)
64+
function dependencies(m::Model, symbolic=false)
6265
ctx = SymbolicContext(DefaultContext())
63-
vi = symbolize(m, VarInfo(m), ctx = ctx)
66+
vi = symbolize(m, VarInfo(m); ctx=ctx)
6467

6568
return dependencies(ctx, symbolic)
6669
end
6770

68-
6971
function symbolic_logp(m::Model)
7072
vi, θ = symbolize(m)
7173
lp = DynamicPPL.getlogp(vi)

src/symbolic/contexts.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ function DynamicPPL.tilde_assume(rng, ctx::SymbolicContext, sampler, right, vn,
1717
return DynamicPPL.tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi)
1818
end
1919

20-
2120
# TODO: Make it more useful when working with symbolic observations.
2221
# observe
2322
function DynamicPPL.tilde_observe(ctx::SymbolicContext, sampler, right, left, vi)

src/symbolic/rules.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import Bijectors
2-
import Symbolics
1+
using Bijectors: Bijectors
2+
using Symbolics: Symbolics
33
using Symbolics.SymbolicUtils
44

55
Symbolics.@register Bijectors.logpdf_with_trans(dist, r, istrans)
@@ -11,8 +11,12 @@ islogpdf(x) = false
1111

1212
# HACK: Apparently this is needed for disambiguiation.
1313
# TODO: Open issue.
14-
Symbolics.:<(a::Real, b::Symbolics.Num) = Symbolics.:<(Symbolics.value(a), Symbolics.value(b))
15-
Symbolics.:<(a::Symbolics.Num, b::Real) = Symbolics.:<(Symbolics.value(a), Symbolics.value(b))
14+
function Symbolics.:<(a::Real, b::Symbolics.Num)
15+
return Symbolics.:<(Symbolics.value(a), Symbolics.value(b))
16+
end
17+
function Symbolics.:<(a::Symbolics.Num, b::Real)
18+
return Symbolics.:<(Symbolics.value(a), Symbolics.value(b))
19+
end
1620

1721
#############
1822
### Rules ###
@@ -22,13 +26,15 @@ const rmnum_rule = @rule (~x) => Symbolics.value(~x)
2226
const addnum_rule = @rule (~x) => Symbolics.Num(~x)
2327

2428
# In the case where we want to work directly with the `x ~ Distribution` statements, the following rules can be useful:
25-
const logpdf_rule = @rule (~x ~ ~d) => Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x));
29+
const logpdf_rule = @rule (~x ~ ~d) =>
30+
Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x));
2631
const rand_rule = @rule (~x ~ ~d) => Distributions.rand(Symbolics.Num(~d))
2732

2833
# We don't want to trace into `Bijectors.logpdf_with_trans`, so we just replace it with `logpdf`.
2934
islogpdf_with_trans(f::Function) = f === Bijectors.logpdf_with_trans
3035
islogpdf_with_trans(x) = false
31-
const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) => logpdf(~dist, ~x)
36+
const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) =>
37+
logpdf(~dist, ~x)
3238

3339
# Attempt to expand `logpdf` to get analytical expressions.
3440
# The idea is that `getlogpdf(d, args)` should return a method of the following signature:
@@ -39,11 +45,8 @@ const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istra
3945
# HACK: this is very hacky but you get the idea
4046
import Distributions: StatsFuns
4147
function getlogpdf(d, args)
42-
replacements = Dict(
43-
:Normal => StatsFuns.normlogpdf,
44-
:Gamma => StatsFuns.gammalogpdf
45-
)
46-
48+
replacements = Dict(:Normal => StatsFuns.normlogpdf, :Gamma => StatsFuns.gammalogpdf)
49+
4750
dsym = Symbol(d)
4851
if haskey(replacements, dsym)
4952
return replacements[dsym]
@@ -52,8 +55,8 @@ function getlogpdf(d, args)
5255
end
5356
end
5457

55-
const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x))
56-
58+
const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) =>
59+
getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x))
5760

5861
#################
5962
### Rewriters ###

src/varinfo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,12 @@ end
124124

125125
function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector, lp::T) where {T}
126126
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
127-
VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi)))
127+
return VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi)))
128128
end
129129

130130
function VarInfo{T}(old_vi::TypedVarInfo, spl, x::AbstractVector) where {T}
131131
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
132-
VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi)))
132+
return VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi)))
133133
end
134134

135135
function VarInfo(

0 commit comments

Comments
 (0)