Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Expand All @@ -40,6 +40,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
TuringDynamicHMCExt = "DynamicHMC"
Expand All @@ -63,6 +64,7 @@ DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.40.6"
EllipticalSliceSampling = "0.5, 1, 2"
FlexiChains = "0.4"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.14"
LinearAlgebra = "1"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ links = InterLinks(
"AbstractMCMC" => "https://turinglang.org/AbstractMCMC.jl/stable/",
"ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/",
"AdvancedVI" => "https://turinglang.org/AdvancedVI.jl/stable/",
"FlexiChains" => "https://pysm.dev/FlexiChains.jl/stable/",
"OrderedCollections" => "https://juliacollections.github.io/OrderedCollections.jl/stable/",
"Distributions" => "https://juliastats.org/Distributions.jl/stable/",
)
Expand Down
23 changes: 10 additions & 13 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@

## Module-wide re-exports

Turing.jl directly re-exports the entire public API of the following packages:

- [Distributions.jl](https://juliastats.org/Distributions.jl)
- [MCMCChains.jl](https://turinglang.org/MCMCChains.jl)

Please see the individual packages for their documentation.
Turing.jl directly re-exports the entire public API of [Distributions.jl](https://juliastats.org/Distributions.jl).
Please see its documentation for more details.

## Individual exports and re-exports

Expand Down Expand Up @@ -49,13 +45,14 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu

### Inference

| Exported symbol | Documentation | Description |
|:----------------- |:------------------------------------------------------------------------- |:----------------------------------------- |
| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model |
| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism |
| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from `MCMCChains.Chains` |
| Exported symbol | Documentation | Description |
|:----------------- |:------------------------------------------------------------------------- |:----------------------------------- |
| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model |
| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism |
| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from an MCMC chain |
| `VNChain` | n/a | Alias for `FlexiChain{VarName}` |

### Samplers

Expand Down
44 changes: 44 additions & 0 deletions ext/TuringMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
module TuringMCMCChainsExt

using Turing
using Turing: AbstractMCMC
using MCMCChains: MCMCChains

"""
loadstate(chain::MCMCChains.Chains)

Load the final state of the sampler from a `MCMCChains.Chains` object.

To save the final state of the sampler, you must use `sample(...; save_state=true)`. If this
argument was not used during sampling, calling `loadstate` will throw an error.
"""
function Turing.Inference.loadstate(chain::MCMCChains.Chains)
if !haskey(chain.info, :samplerstate)
throw(
ArgumentError(
"the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`",
),
)
end
return chain.info[:samplerstate]
end

function AbstractMCMC.bundle_samples(
samples::Vector{<:Vector},
model::DynamicPPL.Model,
spl::Emcee,
state::EmceeState,
::Type{MCMCChains.Chains},
kwargs...,
)
n_walkers = _get_n_walkers(spl)
chains = map(1:n_walkers) do i
this_walker_samples = [s[i] for s in samples]
AbstractMCMC.bundle_samples(
this_walker_samples, model, spl, state, MCMCChains.Chains; kwargs...
)
end
return AbstractMCMC.chainscat(chains...)
end

end
7 changes: 5 additions & 2 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Reexport, ForwardDiff
using Bijectors, StatsFuns, SpecialFunctions
using Statistics, LinearAlgebra
using Libtask
@reexport using Distributions, MCMCChains
@reexport using Distributions
using Compat: pkgversion

using AdvancedVI: AdvancedVI
Expand All @@ -14,6 +14,7 @@ using LogDensityProblems: LogDensityProblems
using StatsAPI: StatsAPI
using StatsBase: StatsBase
using AbstractMCMC
using FlexiChains

using Printf: Printf
using Random: Random
Expand Down Expand Up @@ -181,6 +182,8 @@ export
loadstate,
# kwargs in SMC
might_produce,
@might_produce
@might_produce,
# FlexiChains re-export
VNChain

end
7 changes: 5 additions & 2 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using DynamicPPL:
Model,
DefaultContext
using Distributions, Libtask, Bijectors
using FlexiChains: FlexiChains, VNChain
using LinearAlgebra
using ..Turing: PROGRESS, Turing
using StatsFuns: logsumexp
Expand All @@ -35,7 +36,6 @@ import AdvancedPS
import EllipticalSliceSampling
import LogDensityProblems
import Random
import MCMCChains
import StatsBase: predict

export Hamiltonian,
Expand All @@ -62,7 +62,10 @@ export Hamiltonian,
init_strategy,
loadstate

const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
const DEFAULT_CHAIN_TYPE = VNChain

# Extended in chains extensions
function loadstate end

include("abstractmcmc.jl")
include("repeat_sampler.jl")
Expand Down
19 changes: 0 additions & 19 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,25 +107,6 @@ function AbstractMCMC.sample(
)
end

"""
loadstate(chain::MCMCChains.Chains)

Load the final state of the sampler from a `MCMCChains.Chains` object.

To save the final state of the sampler, you must use `sample(...; save_state=true)`. If this
argument was not used during sampling, calling `loadstate` will throw an error.
"""
function loadstate(chain::MCMCChains.Chains)
if !haskey(chain.info, :samplerstate)
throw(
ArgumentError(
"the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`",
),
)
end
return chain.info[:samplerstate]
end

# TODO(penelopeysm): Remove initialstep and generalise MCMC sampling procedures
function initialstep end

Expand Down
8 changes: 4 additions & 4 deletions src/mcmc/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,18 @@ function AbstractMCMC.step(
end

function AbstractMCMC.bundle_samples(
samples::Vector{<:Vector},
model::AbstractModel,
samples::Vector{<:AbstractVector},
model::DynamicPPL.Model,
spl::Emcee,
state::EmceeState,
chain_type::Type{MCMCChains.Chains};
chain_type::Type{VNChain};
kwargs...,
)
n_walkers = _get_n_walkers(spl)
chains = map(1:n_walkers) do i
this_walker_samples = [s[i] for s in samples]
AbstractMCMC.bundle_samples(
this_walker_samples, model, spl, state, chain_type; kwargs...
this_walker_samples, model, spl, state, VNChain; kwargs...
)
end
return AbstractMCMC.chainscat(chains...)
Expand Down
5 changes: 2 additions & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationBBO = "3e6eede4-6085-4f62-9a71-46d9bc1eb92b"
Expand Down Expand Up @@ -54,15 +54,14 @@ Combinatorics = "1"
DifferentiationInterface = "0.7"
Distributions = "0.25"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.40.6"
DynamicPPL = "0.40.13"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
HypothesisTests = "0.11"
Libtask = "0.9.14"
LinearAlgebra = "1"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.4"
MCMCChains = "7.3.0"
Mooncake = "0.4.182, 0.5"
Optimization = "3, 4, 5"
OptimizationBBO = "0.1, 0.2, 0.3, 0.4"
Expand Down
Loading
Loading