1- import Bijectors
2- import Symbolics
1+ using Bijectors : Bijectors
2+ using Symbolics : Symbolics
33using Symbolics. SymbolicUtils
44
55Symbolics. @register Bijectors. logpdf_with_trans (dist, r, istrans)
@@ -11,8 +11,12 @@ islogpdf(x) = false
1111
1212# HACK: Apparently this is needed for disambiguiation.
1313# TODO : Open issue.
14- Symbolics.:< ₑ (a:: Real , b:: Symbolics.Num ) = Symbolics.:< ₑ (Symbolics. value (a), Symbolics. value (b))
15- Symbolics.:< ₑ (a:: Symbolics.Num , b:: Real ) = Symbolics.:< ₑ (Symbolics. value (a), Symbolics. value (b))
14+ function Symbolics.:< ₑ (a:: Real , b:: Symbolics.Num )
15+ return Symbolics.:< ₑ (Symbolics. value (a), Symbolics. value (b))
16+ end
17+ function Symbolics.:< ₑ (a:: Symbolics.Num , b:: Real )
18+ return Symbolics.:< ₑ (Symbolics. value (a), Symbolics. value (b))
19+ end
1620
1721# ############
1822# ## Rules ###
@@ -22,13 +26,15 @@ const rmnum_rule = @rule (~x) => Symbolics.value(~x)
2226const addnum_rule = @rule (~ x) => Symbolics. Num (~ x)
2327
2428# In the case where we want to work directly with the `x ~ Distribution` statements, the following rules can be useful:
25- const logpdf_rule = @rule (~ x ~ ~ d) => Distributions. logpdf (Symbolics. Num (~ d), Symbolics. Num (~ x));
29+ const logpdf_rule = @rule (~ x ~ ~ d) =>
30+ Distributions. logpdf (Symbolics. Num (~ d), Symbolics. Num (~ x));
2631const rand_rule = @rule (~ x ~ ~ d) => Distributions. rand (Symbolics. Num (~ d))
2732
2833# We don't want to trace into `Bijectors.logpdf_with_trans`, so we just replace it with `logpdf`.
2934islogpdf_with_trans (f:: Function ) = f === Bijectors. logpdf_with_trans
3035islogpdf_with_trans (x) = false
31- const logpdf_with_trans_rule = @rule (~ f:: islogpdf_with_trans )(~ dist, ~ x, ~ istrans) => logpdf (~ dist, ~ x)
36+ const logpdf_with_trans_rule = @rule (~ f:: islogpdf_with_trans )(~ dist, ~ x, ~ istrans) =>
37+ logpdf (~ dist, ~ x)
3238
3339# Attempt to expand `logpdf` to get analytical expressions.
3440# The idea is that `getlogpdf(d, args)` should return a method of the following signature:
@@ -39,11 +45,8 @@ const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istra
3945# HACK: this is very hacky but you get the idea
4046import Distributions: StatsFuns
4147function getlogpdf (d, args)
42- replacements = Dict (
43- :Normal => StatsFuns. normlogpdf,
44- :Gamma => StatsFuns. gammalogpdf
45- )
46-
48+ replacements = Dict (:Normal => StatsFuns. normlogpdf, :Gamma => StatsFuns. gammalogpdf)
49+
4750 dsym = Symbol (d)
4851 if haskey (replacements, dsym)
4952 return replacements[dsym]
@@ -52,8 +55,8 @@ function getlogpdf(d, args)
5255 end
5356end
5457
55- const analytic_rule = @rule (~ f:: islogpdf )((~ d:: isdist )(~~ args), ~ x) => getlogpdf ( ~ d, ~~ args)( map (Symbolics . Num, ( ~~ args)) ... , Symbolics . Num ( ~ x))
56-
58+ const analytic_rule = @rule (~ f:: islogpdf )((~ d:: isdist )(~~ args), ~ x) =>
59+ getlogpdf ( ~ d, ~~ args)( map (Symbolics . Num, ( ~~ args)) ... , Symbolics . Num ( ~ x))
5760
5861# ################
5962# ## Rewriters ###
0 commit comments