Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,14 @@ export AbstractVarInfo,
vectorize,
# Model
Model,
getmissings,
getargnames,
getargumentnames,
getarguments,
getconstantnames,
getconstants,
getobservationnames,
getobservations,
generated_quantities,
isobservation,
# Samplers
Sampler,
SampleFromPrior,
Expand Down
162 changes: 65 additions & 97 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,6 @@
const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)

"""
isassumption(expr)

Return an expression that can be evaluated to check if `expr` is an assumption in the
model.

Let `expr` be `:(x[1])`. It is an assumption in the following cases:
1. `x` is not among the input data to the model,
2. `x` is among the input data to the model but with a value `missing`, or
3. `x` is among the input data to the model with a value other than missing,
but `x[1] === missing`.

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

return quote
let $vn = $(varname(expr))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, __model__) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
# Evaluate the LHS
$expr === missing
end
end
end
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

"""
isliteral(expr)
Expand Down Expand Up @@ -137,8 +103,14 @@ end
function model(mod, linenumbernode, expr, warn)
modelinfo = build_model_info(expr)

# Generate main body
modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn)
# Generate main body and find all variable symbols
modelinfo[:body], modelinfo[:varnames] = generate_mainbody(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should generate_mainbody get a new name? I figured there is no need to traverse the expression twice, so I just added the varname extraction here...

mod, modelinfo[:modeldef][:body], warn
)

# extract parameters and observations from that
modelinfo[:paramnames] = filter(x -> x ∉ modelinfo[:varnames], modelinfo[:allargs_syms])
modelinfo[:obsnames] = setdiff(modelinfo[:allargs_syms], modelinfo[:paramnames])

return build_output(modelinfo, linenumbernode)
end
Expand Down Expand Up @@ -167,8 +139,6 @@ function build_model_info(input_expr)
modelinfo = Dict(
:allargs_exprs => [],
:allargs_syms => [],
:allargs_namedtuple => NamedTuple(),
:defaults_namedtuple => NamedTuple(),
:modeldef => modeldef,
)
return modelinfo
Expand Down Expand Up @@ -196,28 +166,10 @@ function build_model_info(input_expr)
x_ => x
end
end

# Build named tuple expression of the argument symbols and variables of the same name.
allargs_namedtuple = to_namedtuple_expr(allargs_syms)

# Extract default values of the positional and keyword arguments.
default_syms = []
default_vals = []
for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults)
if val !== NO_DEFAULT
push!(default_syms, sym)
push!(default_vals, val)
end
end

# Build named tuple expression of the argument symbols with default values.
defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals)


modelinfo = Dict(
:allargs_exprs => allargs_exprs,
:allargs_syms => allargs_syms,
:allargs_namedtuple => allargs_namedtuple,
:defaults_namedtuple => defaults_namedtuple,
:modeldef => modeldef,
)

Expand All @@ -233,43 +185,50 @@ Generate the body of the main evaluation function from expression `expr` and arg
If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
function generate_mainbody(mod, expr, warn)
varnames = Symbol[]
body = generate_mainbody!(mod, Symbol[], varnames, expr, warn)
return body, varnames
end

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
generate_mainbody!(mod, found_internals, varnames, x, warn) = x
function generate_mainbody!(mod, found_internals, sym::Symbol, warn)
if sym in DEPRECATED_INTERNALNAMES
newsym = Symbol(:_, sym, :__)
Base.depwarn(
"internal variable `$sym` is deprecated, use `$newsym` instead.",
:generate_mainbody!,
)
return generate_mainbody!(mod, found, newsym, warn)
return generate_mainbody!(mod, found_internals, newsym, warn)
end

if warn && sym in INTERNALNAMES && sym ∉ found
if warn && sym in INTERNALNAMES && sym ∉ found_internals
@warn "you are using the internal variable `$sym`"
push!(found, sym)
push!(found_internals, sym)
end

return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
return generate_mainbody!(
mod, found_internals, varnames, macroexpand(mod, expr; recursive=true), warn
)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
!isliteral(L) && push!(varnames, vsym(L))
return Base.remove_linenums!(
generate_dot_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found_internals, varnames, L, warn),
generate_mainbody!(mod, found_internals, varnames, R, warn),
),
)
end
Expand All @@ -278,15 +237,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
!isliteral(L) && push!(varnames, vsym(L))
return Base.remove_linenums!(
generate_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found_internals, varnames, L, warn),
generate_mainbody!(mod, found_internals, varnames, R, warn),
),
)
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
return Expr(
expr.head,
map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)...,
)
end

"""
Expand All @@ -307,26 +270,26 @@ function generate_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
@gensym vn inds isobservation
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume!)(
$isobservation = $(DynamicPPL.isobservation)($vn, __model__)
if $isobservation
$(DynamicPPL.tilde_observe!)(
__context__,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
__varinfo__,
)
else
$(DynamicPPL.tilde_observe!)(
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$inds,
__varinfo__,
)
Expand All @@ -351,26 +314,26 @@ function generate_dot_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
@gensym vn inds isobservation
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
$isobservation = $(DynamicPPL.isobservation)($vn, __model__)
if $isobservation
$(DynamicPPL.dot_tilde_observe!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
)...,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
__varinfo__,
)
else
$(DynamicPPL.dot_tilde_observe!)(
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
)...,
$inds,
__varinfo__,
)
Expand Down Expand Up @@ -413,10 +376,15 @@ function build_output(modelinfo, linenumbernode)

## Build the model function.

# Extract the named tuple expression of all arguments and the default values.
allargs_namedtuple = modelinfo[:allargs_namedtuple]
defaults_namedtuple = modelinfo[:defaults_namedtuple]

# Extract the named tuple expression of all arguments
allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]]
allargs_wrapped = [
x ∈ modelinfo[:obsnames] ? :($(DynamicPPL.Observation)($x)) : :($(DynamicPPL.Constant)($x))
for x in modelinfo[:allargs_syms]
]
allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)]
allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

# Update the function body of the user-specified model.
# We use a name for the anonymous evaluator that does not conflict with other variables.
modeldef = modelinfo[:modeldef]
Expand All @@ -427,11 +395,11 @@ function build_output(modelinfo, linenumbernode)
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
$evaluator = $(MacroTools.combinedef(evaluatordef))
$(allargs_decls...)
return $(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$evaluator,
$allargs_namedtuple,
Comment on lines 413 to 415
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
$(QuoteNode(modeldef[:name])),
$evaluator,
$allargs_namedtuple,
$(QuoteNode(modeldef[:name])), $evaluator, $allargs_namedtuple

$defaults_namedtuple,
)
end

Expand Down
2 changes: 0 additions & 2 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
require_gradient(spl::Sampler) = false
require_particles(spl::Sampler) = false

_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
_getindex(x, inds::Tuple{}) = x

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

# assume
"""
Expand Down
Loading