Skip to content
89 changes: 54 additions & 35 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
const DEPRECATED_INTERNALNAMES = (:_model, :_context, :_varinfo)

for name in INTERNALNAMES
@eval $(Symbol(uppercase(string(name)))) = $(Meta.quot(name))
end

# macro _id(expr)
# return expr
# end

# macro hygienize(expr)
# return Meta.quot(macroexpand(__module__, :(@_id $expr)))
# end

"""
isassumption(expr)
isassumption(expr, vn)

Return an expression that can be evaluated to check if `expr` is an assumption in the
model.
Expand All @@ -14,38 +27,37 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

function isassumption(expr::Union{Symbol,Expr}, vn)
return quote
let $vn = $(varname(expr))
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
$(maybe_view(expr)) === missing
end
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
$(DynamicPPL.inmissings)($vn, __model__)
$(DynamicPPL.inmissings)($vn, __model__)

true
else
false
$(maybe_view(expr)) === missing
end
else
false
end
end
end

# failsafe: a literal is never an assumption
isassumption(expr, vn) = :(false)

"""
contextual_isassumption(context, vn)

Expand Down Expand Up @@ -79,9 +91,6 @@ function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(childcontext(context), prefix(context, vn))
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :(@views($x))
Expand Down Expand Up @@ -306,6 +315,15 @@ generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, war

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
if sym in DEPRECATED_INTERNALNAMES
newsym = Symbol(:_, sym, :__)
Base.depwarn(
"internal variable `$sym` is deprecated, use `$newsym` instead.",
:generate_mainbody!,
)
return generate_mainbody!(mod, found, newsym, warn)
end

if warn && sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$sym`"
push!(found, sym)
Expand Down Expand Up @@ -371,7 +389,7 @@ function generate_tilde(left, right)
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$left = $(DynamicPPL.tilde_assume!)(
__context__,
Expand Down Expand Up @@ -420,7 +438,7 @@ function generate_dot_tilde(left, right)
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
Expand Down Expand Up @@ -465,15 +483,16 @@ function build_output(modelinfo, linenumbernode)
# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__context__::$(DynamicPPL.AbstractContext)),
:($__MODEL__::$(DynamicPPL.Model)),
:($__VARINFO__::$(DynamicPPL.AbstractVarInfo)),
:($__CONTEXT__::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
)

# Delete the keyword arguments.
evaluatordef[:kwargs] = []
evaluatordef[:name] = esc(evaluatordef[:name])

# Replace the user-provided function body with the version created by DynamicPPL.
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
Expand All @@ -485,7 +504,6 @@ function build_output(modelinfo, linenumbernode)
end

## Build the model function.

# Extract the named tuple expression of all arguments and the default values.
allargs_namedtuple = modelinfo[:allargs_namedtuple]
defaults_namedtuple = modelinfo[:defaults_namedtuple]
Expand All @@ -495,6 +513,7 @@ function build_output(modelinfo, linenumbernode)
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
modeldef = modelinfo[:modeldef]
modeldef[:name] = esc(modeldef[:name])
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
return $(DynamicPPL.Model)(
Expand Down