Skip to content
91 changes: 45 additions & 46 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)

for name in INTERNALNAMES
@eval $(Symbol(uppercase(string(name)))) = $(Meta.quot((name)))
end

"""
isassumption(expr)

Expand All @@ -16,11 +20,9 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

return quote
let $vn = $(varname(expr))
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
let vn = $(varname(expr))
if $(DynamicPPL.contextual_isassumption)($__CONTEXT__, vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
Expand All @@ -34,8 +36,8 @@ function isassumption(expr::Union{Symbol,Expr})
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
if !($(DynamicPPL.inargnames)(vn, $__MODEL__)) ||
$(DynamicPPL.inmissings)(vn, $__MODEL__)
true
else
$(maybe_view(expr)) === missing
Expand Down Expand Up @@ -201,7 +203,7 @@ To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
macro model(expr, warn=false)
# include `LineNumberNode` with information about the call site in the
# generated function for easier debugging and interpretation of error messages
return esc(model(__module__, __source__, expr, warn))
return model(__module__, __source__, expr, warn)
end

function model(mod, linenumbernode, expr, warn)
Expand Down Expand Up @@ -325,7 +327,7 @@ function generate_mainbody!(mod, found, sym::Symbol, warn)
end
function generate_mainbody!(mod, found, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]
Meta.isexpr(expr, :$) && return esc(expr.args[1])

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
Expand Down Expand Up @@ -370,40 +372,39 @@ function generate_tilde(left, right)
if isliteral(left)
return quote
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
$__CONTEXT__, $(DynamicPPL.check_tilde_rhs)($right), $left, $__VARINFO__
)
end
end

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
vn = $(varname(left))
inds = $(vinds(left))
isassumption = $(DynamicPPL.isassumption(left))
if isassumption
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
$(DynamicPPL.check_tilde_rhs)($right), vn
)...,
$inds,
__varinfo__,
inds,
$__VARINFO__,
)
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
if !$(DynamicPPL.inargnames)(vn, $__MODEL__)
$left = $(DynamicPPL.getvalue_nested)($__CONTEXT__, vn)
end

$(DynamicPPL.tilde_observe!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
vn,
inds,
$__VARINFO__,
)
end
end
Expand All @@ -419,40 +420,39 @@ function generate_dot_tilde(left, right)
if isliteral(left)
return quote
$(DynamicPPL.dot_tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
$__CONTEXT__, $(DynamicPPL.check_tilde_rhs)($right), $left, $__VARINFO__
)
end
end

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
vn = $(varname(left))
inds = $(vinds(left))
isassumption = $(DynamicPPL.isassumption(left))
if isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), vn
)...,
$inds,
__varinfo__,
inds,
$__VARINFO__,
)
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
if !$(DynamicPPL.inargnames)(vn, $__MODEL__)
$left .= $(DynamicPPL.getvalue_nested)($__CONTEXT__, vn)
end

$(DynamicPPL.dot_tilde_observe!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
vn,
inds,
$__VARINFO__,
)
end
end
Expand All @@ -478,9 +478,9 @@ function build_output(modelinfo, linenumbernode)
# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__context__::$(DynamicPPL.AbstractContext)),
:($__MODEL__::$(DynamicPPL.Model)),
:($__VARINFO__::$(DynamicPPL.AbstractVarInfo)),
:($__CONTEXT__::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
)
Expand All @@ -500,16 +500,15 @@ function build_output(modelinfo, linenumbernode)
# Update the function body of the user-specified model.
# We use a name for the anonymous evaluator that does not conflict with other variables.
modeldef = modelinfo[:modeldef]
@gensym evaluator
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
$evaluator = $(MacroTools.combinedef(evaluatordef))
evaluator = $(MacroTools.combinedef(evaluatordef))
return $(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$evaluator,
evaluator,
$allargs_namedtuple,
$defaults_namedtuple,
)
Expand Down