@@ -3,6 +3,7 @@ using DynamicPPL:
33 AccumulatorTuple,
44 InitContext,
55 InitFromParams,
6+ AbstractInitStrategy,
67 LogJacobianAccumulator,
78 LogLikelihoodAccumulator,
89 LogPriorAccumulator,
@@ -28,6 +29,60 @@ using LogDensityProblems: LogDensityProblems
2829import DifferentiationInterface as DI
2930using Random: Random
3031
32+ """
33+ DynamicPPL.Experimental.fast_evaluate!!(
34+ [rng::Random.AbstractRNG,]
35+ model::Model,
36+ strategy::AbstractInitStrategy,
37+ accs::AccumulatorTuple, params::AbstractVector{<:Real}
38+ )
39+
40+ Evaluate a model using parameters obtained via `strategy`, and only computing the results in
41+ the provided accumulators.
42+
43+ It is assumed that the accumulators passed in have been initialised to appropriate values,
44+ as this function will not reset them. The default constructors for each accumulator will do
45+ this for you correctly.
46+
47+ Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
48+ argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
49+ in the function name.
50+ """
51+ @inline function fast_evaluate!! (
52+ # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
53+ # to extra allocations (even for trivial models) and much slower runtime.
54+ rng:: Random.AbstractRNG ,
55+ model:: Model ,
56+ strategy:: AbstractInitStrategy ,
57+ accs:: AccumulatorTuple ,
58+ )
59+ ctx = InitContext (rng, strategy)
60+ model = DynamicPPL. setleafcontext (model, ctx)
61+ # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
62+ # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
63+ # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
64+ # here.
65+ # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
66+ # it _should_ do, but this is wrong regardless.
67+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
68+ vi = if Threads. nthreads () > 1
69+ param_eltype = DynamicPPL. get_param_eltype (strategy)
70+ accs = map (accs) do acc
71+ DynamicPPL. convert_eltype (float_type_with_fallback (param_eltype), acc)
72+ end
73+ ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
74+ else
75+ OnlyAccsVarInfo (accs)
76+ end
77+ return DynamicPPL. _evaluate!! (model, vi)
78+ end
79+ @inline function fast_evaluate!! (
80+ model:: Model , strategy:: AbstractInitStrategy , accs:: AccumulatorTuple
81+ )
82+ # This `@inline` is also mandatory for performance
83+ return fast_evaluate!! (Random. default_rng (), model, strategy, accs)
84+ end
85+
3186"""
3287 FastLDF(
3388 model::Model,
@@ -213,31 +268,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
213268 varname_ranges:: Dict{VarName,RangeAndLinked}
214269end
215270function (f:: FastLogDensityAt )(params:: AbstractVector{<:Real} )
216- ctx = InitContext (
217- Random. default_rng (),
218- InitFromParams (
219- VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
220- ),
271+ strategy = InitFromParams (
272+ VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
221273 )
222- model = DynamicPPL. setleafcontext (f. model, ctx)
223274 accs = fast_ldf_accs (f. getlogdensity)
224- # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
225- # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
226- # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
227- # here.
228- # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
229- # it _should_ do, but this is wrong regardless.
230- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
231- vi = if Threads. nthreads () > 1
232- accs = map (
233- acc -> DynamicPPL. convert_eltype (float_type_with_fallback (eltype (params)), acc),
234- accs,
235- )
236- ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
237- else
238- OnlyAccsVarInfo (accs)
239- end
240- _, vi = DynamicPPL. _evaluate!! (model, vi)
275+ _, vi = fast_evaluate!! (f. model, strategy, accs)
241276 return f. getlogdensity (vi)
242277end
243278
0 commit comments