Skip to content
19 changes: 13 additions & 6 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function isassumption(expr::Union{Symbol,Expr})
true
else
# Evaluate the LHS
$expr === missing
$(maybe_view(expr)) === missing
end
end
end
Expand All @@ -36,6 +36,13 @@ end
# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x)))

maybe_unwrap_view(x) = x
maybe_unwrap_view(x::SubArray{<:Any, 0}) = x[1]

"""
isliteral(expr)

Expand Down Expand Up @@ -300,7 +307,7 @@ function generate_tilde(left, right)
if isliteral(left)
return quote
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
__context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), __varinfo__
)
end
end
Expand All @@ -325,7 +332,7 @@ function generate_tilde(left, right)
$(DynamicPPL.tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
Expand All @@ -344,7 +351,7 @@ function generate_dot_tilde(left, right)
if isliteral(left)
return quote
$(DynamicPPL.dot_tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
__context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), __varinfo__
)
end
end
Expand All @@ -360,7 +367,7 @@ function generate_dot_tilde(left, right)
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
$inds,
__varinfo__,
Expand All @@ -369,7 +376,7 @@ function generate_dot_tilde(left, right)
$(DynamicPPL.dot_tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
Expand Down