diff --git a/src/compiler.jl b/src/compiler.jl index 25530696d..829b08125 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2,7 +2,7 @@ const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__ const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng) """ - isassumption(expr) + isassumption(expr, vn) Return an expression that can be evaluated to check if `expr` is an assumption in the model. @@ -15,38 +15,38 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. """ -function isassumption(expr::Union{Symbol,Expr}) - vn = gensym(:vn) - +function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr))) return quote - let $vn = $(AbstractPPL.drop_escape(varname(expr))) - if $(DynamicPPL.contextual_isassumption)(__context__, $vn) - # Considered an assumption by `__context__` which means either: - # 1. We hit the default implementation, e.g. using `DefaultContext`, - # which in turn means that we haven't considered if it's one of - # the model arguments, hence we need to check this. - # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, - # i.e. we're trying to condition one of the latent variables. - # In this case, the below will return `true` since the first branch - # will be hit. - # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, - # i.e. we're trying to override the value. This is currently NOT supported. - # TODO: Support by adding context to model, and use `model.args` - # as the default conditioning. Then we no longer need to check `inargnames` - # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) - true - else - $(maybe_view(expr)) === missing - end + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) || + $(DynamicPPL.inmissings)($vn, __model__) + true else - false + $(maybe_view(expr)) === missing end + else + false end end end +# failsafe: a literal is never an assumption +isassumption(expr, vn) = :(false) +isassumption(expr) = :(false) + """ contextual_isassumption(context, vn) @@ -80,9 +80,6 @@ function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(childcontext(context), prefix(context, vn)) end -# failsafe: a literal is never an assumption -isassumption(expr) = :(false) - # If we're working with, say, a `Symbol`, then we're not going to `view`. maybe_view(x) = x maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@views($x))) @@ -396,7 +393,7 @@ function generate_tilde(left, right) # more selective with our escape. Until that's the case, we remove them all. return quote $vn = $(AbstractPPL.drop_escape(varname(left))) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption $(generate_tilde_assume(left, right, vn)) else @@ -417,8 +414,8 @@ function generate_tilde(left, right) end function generate_tilde_assume(left, right, vn) - expr = :( - $left = $(DynamicPPL.tilde_assume!)( + tilde = :( + $(DynamicPPL.tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, @@ -426,11 +423,15 @@ function generate_tilde_assume(left, right, vn) ) return if left isa Expr - AbstractPPL.drop_escape( - Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) - ) + # `x[i] = ...` needs to become `x = set(x, @lens(_[i]), ...)` + @gensym lens + vn_name = AbstractPPL.vsym(left) + quote + $lens = $(BangBang.prefermutation)($(DynamicPPL.getindexing)($vn)) + $vn_name = $(Setfield.set)($vn_name, $lens, $tilde) + end else - return expr + return :($left = $tilde) end end @@ -447,7 +448,7 @@ function generate_dot_tilde(left, right) @gensym vn isassumption return quote $vn = $(AbstractPPL.drop_escape(varname(left))) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption $(generate_dot_tilde_assume(left, right, vn)) else