diff --git a/Project.toml b/Project.toml index 38b0736b74..6b9ff305ea 100644 --- a/Project.toml +++ b/Project.toml @@ -20,11 +20,11 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2" +FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -40,6 +40,8 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" [extensions] TuringDynamicHMCExt = "DynamicHMC" +TuringMCMCChainsExt = "MCMCChains" @@ -63,6 +65,7 @@ DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" DynamicPPL = "0.40.6" EllipticalSliceSampling = "0.5, 1, 2" +FlexiChains = "0.4" ForwardDiff = "0.10.3, 1" Libtask = "0.9.14" LinearAlgebra = "1" diff --git a/docs/make.jl b/docs/make.jl index 857cfbcb3a..59550ea433 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,6 +11,7 @@ links = InterLinks( "AbstractMCMC" => "https://turinglang.org/AbstractMCMC.jl/stable/", "ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/", "AdvancedVI" => "https://turinglang.org/AdvancedVI.jl/stable/", + "FlexiChains" => "https://pysm.dev/FlexiChains.jl/stable/", "OrderedCollections" => "https://juliacollections.github.io/OrderedCollections.jl/stable/", "Distributions" => "https://juliastats.org/Distributions.jl/stable/", ) diff --git a/docs/src/api.md b/docs/src/api.md index ca08e166bb..5676b33dbc 100644 --- a/docs/src/api.md +++ 
b/docs/src/api.md @@ -2,12 +2,8 @@ ## Module-wide re-exports -Turing.jl directly re-exports the entire public API of the following packages: - - - [Distributions.jl](https://juliastats.org/Distributions.jl) - - [MCMCChains.jl](https://turinglang.org/MCMCChains.jl) - -Please see the individual packages for their documentation. +Turing.jl directly re-exports the entire public API of [Distributions.jl](https://juliastats.org/Distributions.jl). +Please see its documentation for more details. ## Individual exports and re-exports @@ -49,13 +45,14 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu ### Inference -| Exported symbol | Documentation | Description | -|:----------------- |:------------------------------------------------------------------------- |:----------------------------------------- | -| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model | -| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads | -| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes | -| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism | -| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from `MCMCChains.Chains` | +| Exported symbol | Documentation | Description | +|:----------------- |:------------------------------------------------------------------------- |:----------------------------------- | +| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model | +| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads | +| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes | +| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism | +| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved 
state from an MCMC chain | +| `VNChain` | n/a | Alias for `FlexiChain{VarName}` | ### Samplers diff --git a/ext/TuringMCMCChainsExt.jl b/ext/TuringMCMCChainsExt.jl new file mode 100644 index 0000000000..80fca50ab8 --- /dev/null +++ b/ext/TuringMCMCChainsExt.jl @@ -0,0 +1,44 @@ +module TuringMCMCChainsExt + +using Turing +using Turing: AbstractMCMC, DynamicPPL, Inference +using MCMCChains: MCMCChains + +""" + loadstate(chain::MCMCChains.Chains) + +Load the final state of the sampler from a `MCMCChains.Chains` object. + +To save the final state of the sampler, you must use `sample(...; save_state=true)`. If this +argument was not used during sampling, calling `loadstate` will throw an error. +""" +function Turing.Inference.loadstate(chain::MCMCChains.Chains) + if !haskey(chain.info, :samplerstate) + throw( + ArgumentError( + "the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`", + ), + ) + end + return chain.info[:samplerstate] +end + +function AbstractMCMC.bundle_samples( + samples::Vector{<:AbstractVector}, + model::DynamicPPL.Model, + spl::Emcee, + state::Inference.EmceeState, + ::Type{MCMCChains.Chains}; + kwargs..., +) + n_walkers = Inference._get_n_walkers(spl) + chains = map(1:n_walkers) do i + this_walker_samples = [s[i] for s in samples] + AbstractMCMC.bundle_samples( + this_walker_samples, model, spl, state, MCMCChains.Chains; kwargs... + ) + end + return AbstractMCMC.chainscat(chains...) 
+end + +end diff --git a/src/Turing.jl b/src/Turing.jl index 87f09a4533..c06229e4c6 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -4,7 +4,7 @@ using Reexport, ForwardDiff using Bijectors, StatsFuns, SpecialFunctions using Statistics, LinearAlgebra using Libtask -@reexport using Distributions, MCMCChains +@reexport using Distributions using Compat: pkgversion using AdvancedVI: AdvancedVI @@ -14,6 +14,7 @@ using LogDensityProblems: LogDensityProblems using StatsAPI: StatsAPI using StatsBase: StatsBase using AbstractMCMC +using FlexiChains using Printf: Printf using Random: Random @@ -181,6 +182,8 @@ export loadstate, # kwargs in SMC might_produce, - @might_produce + @might_produce, + # FlexiChains re-export + VNChain end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 73f0661e61..26d4a6aeb4 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -16,6 +16,7 @@ using DynamicPPL: Model, DefaultContext using Distributions, Libtask, Bijectors +using FlexiChains: FlexiChains, VNChain using LinearAlgebra using ..Turing: PROGRESS, Turing using StatsFuns: logsumexp @@ -35,7 +36,6 @@ import AdvancedPS import EllipticalSliceSampling import LogDensityProblems import Random -import MCMCChains import StatsBase: predict export Hamiltonian, @@ -62,7 +62,10 @@ export Hamiltonian, init_strategy, loadstate -const DEFAULT_CHAIN_TYPE = MCMCChains.Chains +const DEFAULT_CHAIN_TYPE = VNChain + +# Extended in chains extensions +function loadstate end include("abstractmcmc.jl") include("repeat_sampler.jl") diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index f867ab5344..85bbd693da 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -107,25 +107,6 @@ function AbstractMCMC.sample( ) end -""" - loadstate(chain::MCMCChains.Chains) - -Load the final state of the sampler from a `MCMCChains.Chains` object. - -To save the final state of the sampler, you must use `sample(...; save_state=true)`. 
If this -argument was not used during sampling, calling `loadstate` will throw an error. -""" -function loadstate(chain::MCMCChains.Chains) - if !haskey(chain.info, :samplerstate) - throw( - ArgumentError( - "the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`", - ), - ) - end - return chain.info[:samplerstate] -end - # TODO(penelopeysm): Remove initialstep and generalise MCMC sampling procedures function initialstep end diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index a776579fbe..1e166d44a8 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -118,18 +118,18 @@ function AbstractMCMC.step( end function AbstractMCMC.bundle_samples( - samples::Vector{<:Vector}, - model::AbstractModel, + samples::Vector{<:AbstractVector}, + model::DynamicPPL.Model, spl::Emcee, state::EmceeState, - chain_type::Type{MCMCChains.Chains}; + chain_type::Type{VNChain}; kwargs..., ) n_walkers = _get_n_walkers(spl) chains = map(1:n_walkers) do i this_walker_samples = [s[i] for s in samples] AbstractMCMC.bundle_samples( - this_walker_samples, model, spl, state, chain_type; kwargs... + this_walker_samples, model, spl, state, VNChain; kwargs... ) end return AbstractMCMC.chainscat(chains...) 
diff --git a/test/Project.toml b/test/Project.toml index dee67a1dcd..75bf1bc80e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -15,6 +15,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" @@ -22,7 +23,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationBBO = "3e6eede4-6085-4f62-9a71-46d9bc1eb92b" @@ -54,7 +54,7 @@ Combinatorics = "1" DifferentiationInterface = "0.7" Distributions = "0.25" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.40.6" +DynamicPPL = "0.40.13" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" @@ -62,7 +62,6 @@ Libtask = "0.9.14" LinearAlgebra = "1" LogDensityProblems = "2" LogDensityProblemsAD = "1.4" -MCMCChains = "7.3.0" Mooncake = "0.4.182, 0.5" Optimization = "3, 4, 5" OptimizationBBO = "0.1, 0.2, 0.3, 0.4" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 73f83040fc..a51bc474fd 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -7,9 +7,9 @@ using Distributions: sample using AbstractMCMC: AbstractMCMC import DynamicPPL using DynamicPPL: filldist +using FlexiChains: FlexiChains, VNChain import ForwardDiff using LinearAlgebra: I -import MCMCChains import Random using Random: Xoshiro import ReverseDiff @@ -42,9 +42,7 @@ using 
Turing rng2 = Xoshiro(5) chain2 = sample(rng2, model, sampler, MCMCThreads(), 10, 4) - # For HMC, the first step does not have stats, so we need to use isequal to - # avoid comparing `missing`s - @test isequal(chain1.value, chain2.value) + @test FlexiChains.has_same_data(chain1, chain2) end # Should also be stable with an explicit RNG @@ -57,7 +55,7 @@ using Turing Random.seed!(rng, local_seed) chain2 = sample(rng, model, sampler, MCMCThreads(), 10, 4) - @test isequal(chain1.value, chain2.value) + @test FlexiChains.has_same_data(chain1, chain2) end end @@ -85,28 +83,32 @@ using Turing @testset "single-chain" begin chn1 = sample(demo(), StaticSampler(), 10; save_state=true) - @test chn1.info.samplerstate isa DynamicPPL.AbstractVarInfo - chn2 = sample(demo(), StaticSampler(), 10; initial_state=loadstate(chn1)) - xval = chn1[:x][1] - @test all(chn2[:x] .== xval) + @test FlexiChains.last_sampler_state(chn1) isa + AbstractVector{<:DynamicPPL.AbstractVarInfo} + @test length(FlexiChains.last_sampler_state(chn1)) == 1 + chn2 = sample(demo(), StaticSampler(), 10; initial_state=only(loadstate(chn1))) + xval = chn1[@varname(x)][1] + @test all(chn2[@varname(x)] .== xval) end @testset "multiple-chain" for nchains in [1, 3] chn1 = sample( demo(), StaticSampler(), MCMCThreads(), 10, nchains; save_state=true ) - @test chn1.info.samplerstate isa AbstractVector{<:DynamicPPL.AbstractVarInfo} - @test length(chn1.info.samplerstate) == nchains + @test FlexiChains.last_sampler_state(chn1) isa + AbstractVector{<:DynamicPPL.AbstractVarInfo} + @test length(FlexiChains.last_sampler_state(chn1)) == nchains + niters = 10 chn2 = sample( demo(), StaticSampler(), MCMCThreads(), - 10, + niters, nchains; initial_state=loadstate(chn1), ) - xval = chn1[:x][1, :] - @test all(i -> chn2[:x][i, :] == xval, 1:10) + xval = chn1[@varname(x), iter=1] + @test all(i -> chn2[@varname(x), iter=i] == xval, 1:niters) end end @@ -119,12 +121,12 @@ using Turing check_gdemo(chn1) chn1_contd = sample( - 
StableRNG(seed), gdemo_default, alg1, 2_000; initial_state=loadstate(chn1) + StableRNG(seed), gdemo_default, alg1, 2_000; initial_state=only(loadstate(chn1)) ) check_gdemo(chn1_contd) chn1_contd2 = sample( - StableRNG(seed), gdemo_default, alg1, 2_000; initial_state=loadstate(chn1) + StableRNG(seed), gdemo_default, alg1, 2_000; initial_state=only(loadstate(chn1)) ) check_gdemo(chn1_contd2) @@ -139,7 +141,7 @@ using Turing check_gdemo(chn2) chn2_contd = sample( - StableRNG(seed), gdemo_default, alg2, 2_000; initial_state=loadstate(chn2) + StableRNG(seed), gdemo_default, alg2, 2_000; initial_state=only(loadstate(chn2)) ) check_gdemo(chn2_contd) @@ -154,7 +156,7 @@ using Turing check_gdemo(chn3) chn3_contd = sample( - StableRNG(seed), gdemo_default, alg3, 5_000; initial_state=loadstate(chn3) + StableRNG(seed), gdemo_default, alg3, 5_000; initial_state=only(loadstate(chn3)) ) check_gdemo(chn3_contd) end @@ -164,16 +166,16 @@ using Turing @testset "Single-threaded vanilla" begin chains = sample(StableRNG(seed), gdemo_d(), Prior(), N) - @test chains isa MCMCChains.Chains - @test mean(chains, :s) ≈ 3 atol = 0.11 - @test mean(chains, :m) ≈ 0 atol = 0.1 + @test chains isa VNChain + @test mean(chains[@varname(s)]) ≈ 3 atol = 0.11 + @test mean(chains[@varname(m)]) ≈ 0 atol = 0.1 end @testset "Multi-threaded" begin chains = sample(StableRNG(seed), gdemo_d(), Prior(), MCMCThreads(), N, 4) - @test chains isa MCMCChains.Chains - @test mean(chains, :s) ≈ 3 atol = 0.11 - @test mean(chains, :m) ≈ 0 atol = 0.1 + @test chains isa VNChain + @test mean(chains[@varname(s)]) ≈ 3 atol = 0.11 + @test mean(chains[@varname(m)]) ≈ 0 atol = 0.1 end @testset "accumulators are set correctly" begin @@ -187,39 +189,29 @@ using Turing return nothing end chain = sample(coloneq(), Prior(), N) - @test chain isa MCMCChains.Chains - @test all(x -> x == 1.0, chain[:z]) + @test chain isa VNChain + @test all(x -> x == 1.0, chain[@varname(z)]) # And for the same reason we should also make sure that the logp 
# components are correctly calculated. - @test isapprox(chain[:logprior], logpdf.(Normal(), chain[:x])) - @test isapprox(chain[:loglikelihood], logpdf.(Normal.(chain[:x]), 10.0)) + @test isapprox(chain[:logprior], logpdf.(Normal(), chain[@varname(x)])) + @test isapprox( + chain[:loglikelihood], logpdf.(Normal.(chain[@varname(x)]), 10.0) + ) @test isapprox(chain[:logjoint], chain[:logprior] .+ chain[:loglikelihood]) # And that the outcome is not influenced by the likelihood - @test mean(chain, :x) ≈ 0.0 atol = 0.1 - end - end - - @testset "chain ordering" begin - for alg in (Prior(), Emcee(10, 2.0)) - chain_sorted = sample(StableRNG(seed), gdemo_default, alg, 1; sort_chain=true) - @test names(MCMCChains.get_sections(chain_sorted, :parameters)) == [:m, :s] - - chain_unsorted = sample( - StableRNG(seed), gdemo_default, alg, 1; sort_chain=false - ) - @test names(MCMCChains.get_sections(chain_unsorted, :parameters)) == [:s, :m] + @test mean(chain[@varname(x)]) ≈ 0.0 atol = 0.1 end end @testset "chain iteration numbers" begin for alg in (Prior(), Emcee(10, 2.0)) chain = sample(StableRNG(seed), gdemo_default, alg, 10) - @test range(chain) == 1:10 + @test FlexiChains.iter_indices(chain) == 1:10 chain = sample( StableRNG(seed), gdemo_default, alg, 10; discard_initial=5, thinning=2 ) - @test range(chain) == range(6; step=2, length=10) + @test FlexiChains.iter_indices(chain) == range(6; step=2, length=10) end end @@ -241,8 +233,8 @@ using Turing check_numerical(res2, [:y], [0.5]; atol=0.1) # Check that all xs are 1. 
- @test all(isone, res1[:x]) - @test all(isone, res2[:x]) + @test all(isone, res1[@varname(x)]) + @test all(isone, res2[@varname(x)]) end @testset "beta binomial" begin @@ -268,9 +260,9 @@ using Turing chn_p = sample(StableRNG(seed), testbb(obs), pg, 2_000) chn_g = sample(StableRNG(seed), testbb(obs), gibbs, 2_000) - check_numerical(chn_s, [:p], [meanp]; atol=0.05) - check_numerical(chn_p, [:x], [meanp]; atol=0.1) - check_numerical(chn_g, [:x], [meanp]; atol=0.1) + check_numerical(chn_s, [@varname(p)], [meanp]; atol=0.05) + check_numerical(chn_p, [@varname(x)], [meanp]; atol=0.1) + check_numerical(chn_g, [@varname(x)], [meanp]; atol=0.1) end @testset "forbid global" begin @@ -396,12 +388,12 @@ using Turing # as a statistic (which is the same for all 'iterations'). So we can just pick the # first one. res_smc = sample(StableRNG(seed), test(), smc, N) - @test all(isone, res_smc[:x]) + @test all(isone, res_smc[@varname(x)]) smc_logevidence = first(res_smc[:logevidence]) @test smc_logevidence ≈ 2 * log(0.5) res_pg = sample(StableRNG(seed), test(), pg, 100) - @test all(isone, res_pg[:x]) + @test all(isone, res_pg[@varname(x)]) end @testset "sample" begin @@ -622,7 +614,7 @@ using Turing end # Can't test with HMC/NUTS because some AD backends error; see # https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802 - @test sample(e(), Prior(), 100) isa MCMCChains.Chains + @test sample(e(), Prior(), 100) isa VNChain end end diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index d743f25318..08c96074c1 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -2,6 +2,7 @@ module TuringAbstractMCMCTests using AbstractMCMC: AbstractMCMC using DynamicPPL: DynamicPPL +using FlexiChains: Extra using Random: AbstractRNG using Test: @test, @testset, @test_throws using Turing @@ -37,7 +38,7 @@ end vi::DynamicPPL.VarInfo=DynamicPPL.VarInfo(rng, model); kwargs..., ) - return vi, nothing + return DynamicPPL.ParamsWithStats(vi, 
model), nothing end @testset "init_strategy" begin @@ -55,13 +56,12 @@ end model = coinflip() lptrue = logpdf(Binomial(25, 0.2), 10) let inits = InitFromParams((; p=0.2)) - varinfos = sample(model, spl, 1; initial_params=inits, progress=false) - varinfo = only(varinfos) - @test varinfo[@varname(p)] == 0.2 - @test DynamicPPL.getlogjoint(varinfo) == lptrue + chn = sample(model, spl, 1; initial_params=inits, progress=false) + @test only(chn[@varname(p)]) == 0.2 + @test only(chn[Extra(:logjoint)]) == lptrue # parallel sampling - chains = sample( + chn = sample( model, spl, MCMCThreads(), @@ -70,11 +70,8 @@ end initial_params=fill(inits, 10), progress=false, ) - for c in chains - varinfo = only(c) - @test varinfo[@varname(p)] == 0.2 - @test DynamicPPL.getlogjoint(varinfo) == lptrue - end + @test all(chn[@varname(p)] .== 0.2) + @test all(chn[Extra(:logjoint)] .== lptrue) end # check that Vector no longer works @@ -98,14 +95,13 @@ end InitFromParams(Dict(@varname(s) => 4, @varname(m) => -1)), Dict(@varname(s) => 4, @varname(m) => -1), ) - chain = sample(model, spl, 1; initial_params=inits, progress=false) - varinfo = only(chain) - @test varinfo[@varname(s)] == 4 - @test varinfo[@varname(m)] == -1 - @test DynamicPPL.getlogjoint(varinfo) == lptrue + chn = sample(model, spl, 1; initial_params=inits, progress=false) + @test only(chn[@varname(s)]) == 4 + @test only(chn[@varname(m)]) == -1 + @test only(chn[Extra(:logjoint)]) == lptrue # parallel sampling - chains = sample( + chn = sample( model, spl, MCMCThreads(), @@ -114,12 +110,9 @@ end initial_params=fill(inits, 10), progress=false, ) - for c in chains - varinfo = only(c) - @test varinfo[@varname(s)] == 4 - @test varinfo[@varname(m)] == -1 - @test DynamicPPL.getlogjoint(varinfo) == lptrue - end + @test all(chn[@varname(s)] .== 4) + @test all(chn[@varname(m)] .== -1) + @test all(chn[Extra(:logjoint)] .== lptrue) end # set only m = -1 @@ -134,12 +127,11 @@ end Dict(@varname(m) => -1), ) chain = sample(model, spl, 1; 
initial_params=inits, progress=false) - varinfo = only(chain) - @test !ismissing(varinfo[@varname(s)]) - @test varinfo[@varname(m)] == -1 + @test !ismissing(only(chain[@varname(s)])) + @test only(chain[@varname(m)]) == -1 # parallel sampling - chains = sample( + c = sample( model, spl, MCMCThreads(), @@ -148,11 +140,8 @@ end initial_params=fill(inits, 10), progress=false, ) - for c in chains - varinfo = only(c) - @test !ismissing(varinfo[@varname(s)]) - @test varinfo[@varname(m)] == -1 - end + @test !any(ismissing, c[@varname(s)]) + @test all(c[@varname(m)] .== -1) end end end diff --git a/test/mcmc/emcee.jl b/test/mcmc/emcee.jl index ee786b8584..5904f29891 100644 --- a/test/mcmc/emcee.jl +++ b/test/mcmc/emcee.jl @@ -6,6 +6,7 @@ using Distributions: sample using DynamicPPL: DynamicPPL using Random: Random, Xoshiro using StableRNGs: StableRNG +using FlexiChains: FlexiChains using Test: @test, @test_throws, @testset using Turing @@ -35,7 +36,7 @@ using Turing chain1 = sample(rng1, gdemo_default, spl, 1) rng2 = Xoshiro(1234) chain2 = sample(rng2, gdemo_default, spl, 1) - @test Array(chain1) == Array(chain2) + @test FlexiChains.has_same_data(chain1, chain2) initial_nt = DynamicPPL.InitFromParams((s=2.0, m=1.0)) # Initial parameters have to be specified for every walker diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index e497fdde3a..ab81287868 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -1,7 +1,7 @@ module ESSTests using ..Models: MoGtest, MoGtest_default, gdemo, gdemo_default -using ..NumericalTests: check_MoGtest_default, check_numerical +using ..NumericalTests: check_MoGtest_default, check_numerical, check_gdemo using ..SamplerTestUtils: test_rng_respected, test_sampler_analytical using Distributions: Normal, sample using DynamicPPL: DynamicPPL @@ -50,18 +50,18 @@ using Turing @testset "demo_default" begin chain = sample(StableRNG(seed), demo_default, ESS(), 5_000) - check_numerical(chain, [:m], [0.8]; atol=0.1) + check_numerical(chain, [@varname(m)], 
[0.8]; atol=0.1) end @testset "demodot_default" begin chain = sample(StableRNG(seed), demodot_default, ESS(), 5_000) - check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8]; atol=0.1) + check_numerical(chain, [@varname(m[1]), @varname(m[2])], [0.0, 0.8]; atol=0.1) end @testset "gdemo with CSMC + ESS" begin alg = Gibbs(:s => CSMC(15), :m => ESS()) chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 3_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + check_gdemo(chain; atol=0.1) end @testset "MoGtest_default with CSMC + ESS" begin @@ -97,17 +97,20 @@ using Turing # Test that ESS can sample multiple variables regardless of whether they are under the # same symbol or not. @testset "Multiple variables" begin + zdist = Beta(2.0, 2.0) + ydist = Normal(-3.0, 3.0) + @model function xy() - z ~ Beta(2.0, 2.0) + z ~ zdist x ~ Normal(z, 2.0) - return y ~ Normal(-3.0, 3.0) + return y ~ ydist end @model function x12() - z ~ Beta(2.0, 2.0) + z ~ zdist x = Vector{Float64}(undef, 2) x[1] ~ Normal(z, 2.0) - return x[2] ~ Normal(-3.0, 3.0) + return x[2] ~ ydist end num_samples = 10_000 @@ -117,9 +120,11 @@ using Turing chn1 = sample(StableRNG(23), xy(), spl_xy, num_samples) chn2 = sample(StableRNG(23), x12(), spl_x, num_samples) - @test chn1.value ≈ chn2.value - @test mean(chn1[:z]) ≈ mean(Beta(2.0, 2.0)) atol = 0.05 - @test mean(chn1[:y]) ≈ -3.0 atol = 0.05 + @test chn1[@varname(z)] == chn2[@varname(z)] + @test chn1[@varname(x)] == chn2[@varname(x[1])] + @test chn1[@varname(y)] == chn2[@varname(x[2])] + @test mean(chn1[@varname(z)]) ≈ mean(zdist) atol = 0.05 + @test mean(chn1[@varname(y)]) ≈ mean(ydist) atol = 0.05 end end diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index 38f8f90142..47de30b776 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -84,9 +84,9 @@ using Turing.Inference: AdvancedHMC chn = sample( model, externalsampler(MySampler()), 10; initial_params=InitFromParams((a=a, b=b)) ) - @test chn 
isa MCMCChains.Chains - @test all(chn[:a] .== a) - @test all(chn[:b] .== b) + @test chn isa VNChain + @test all(chn[@varname(a)] .== a) + @test all(chn[@varname(b)] .== b) expected_logpdf = logpdf(Beta(2, 2), a) + logpdf(Normal(a), b) @test all(chn[:logjoint] .== expected_logpdf) @test all(chn[:logprior] .== expected_logpdf) @@ -237,7 +237,7 @@ end sampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained=true ) chn = sample(logp_check(), Gibbs(@varname(x) => sampler_ext), 100) - @test isapprox(logpdf.(Normal(), chn[:x]), chn[:logjoint]) + @test isapprox(logpdf.(Normal(), chn[@varname(x)]), chn[:logjoint]) end end @@ -270,7 +270,7 @@ end sampler = initialize_mh_rw(model) sampler_ext = externalsampler(sampler; unconstrained=true) chn = sample(logp_check(), Gibbs(@varname(x) => sampler_ext), 100) - @test isapprox(logpdf.(Normal(), chn[:x]), chn[:logjoint]) + @test isapprox(logpdf.(Normal(), chn[@varname(x)]), chn[:logjoint]) end end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 34107202da..3b94a91019 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -9,9 +9,11 @@ using ..NumericalTests: two_sample_test import Combinatorics using AbstractMCMC: AbstractMCMC +using AbstractPPL: AbstractPPL using Distributions: InverseGamma, Normal using Distributions: sample using DynamicPPL: DynamicPPL +using FlexiChains: FlexiChains using ForwardDiff: ForwardDiff using Random: Random, Xoshiro using ReverseDiff: ReverseDiff @@ -262,7 +264,7 @@ end ) chain1 = sample(Xoshiro(23), gdemo_default, sampler1, 10) chain2 = sample(Xoshiro(23), gdemo_default, sampler1, 10) - @test chain1.value == chain2.value + @test FlexiChains.has_same_data(chain1, chain2) end @testset "Gibbs warmup" begin @@ -380,12 +382,12 @@ end @varname(s) => RepeatSampler(HMC(0.1, 5), 3), @varname(m) => RepeatSampler(PG(10), 2), ) - @test sample(gdemo_default, s1, N) isa MCMCChains.Chains - @test sample(gdemo_default, s2, N) isa MCMCChains.Chains - @test sample(gdemo_default, s3, N) isa MCMCChains.Chains 
- @test sample(gdemo_default, s4, N) isa MCMCChains.Chains - @test sample(gdemo_default, s5, N) isa MCMCChains.Chains - @test sample(gdemo_default, s6, N) isa MCMCChains.Chains + @test sample(gdemo_default, s1, N) isa VNChain + @test sample(gdemo_default, s2, N) isa VNChain + @test sample(gdemo_default, s3, N) isa VNChain + @test sample(gdemo_default, s4, N) isa VNChain + @test sample(gdemo_default, s5, N) isa VNChain + @test sample(gdemo_default, s6, N) isa VNChain end # Test various combinations of samplers against models for which we know the analytical @@ -394,28 +396,28 @@ end @testset "CSMC and HMC on gdemo" begin alg = Gibbs(:s => CSMC(15), :m => HMC(0.2, 4)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) - check_numerical(chain, [:m], [7 / 6]; atol=0.15) + check_numerical(chain, [@varname(m)], [7 / 6]; atol=0.15) # Be more relaxed with the tolerance of the variance. - check_numerical(chain, [:s], [49 / 24]; atol=0.35) + check_numerical(chain, [@varname(s)], [49 / 24]; atol=0.35) end @testset "MH and HMCDA on gdemo" begin alg = Gibbs(:s => MH(), :m => HMCDA(200, 0.65, 0.3)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + check_gdemo(chain; atol=0.1) end @testset "CSMC and ESS on gdemo" begin alg = Gibbs(:s => CSMC(15), :m => ESS()) chain = sample(gdemo(1.5, 2.0), alg, 3_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + check_gdemo(chain; atol=0.1) end # TODO(mhauru) Why is this in the Gibbs test suite? 
@testset "CSMC on gdemo" begin alg = CSMC(15) chain = sample(gdemo(1.5, 2.0), alg, 4_000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) + check_gdemo(chain; atol=0.1) end @testset "PG and HMC on MoGtest_default" begin @@ -467,12 +469,7 @@ end model = gdemo_copy() @nospecialize function AbstractMCMC.bundle_samples( - samples::Vector, - ::typeof(model), - ::Gibbs, - state, - ::Type{MCMCChains.Chains}; - kwargs..., + samples::Vector, ::typeof(model), ::Gibbs, state, ::Type{VNChain}; kwargs... ) samples isa Vector{<:DynamicPPL.ParamsWithStats} || error("incorrect transitions") @@ -516,29 +513,23 @@ end @test size(chn, 1) == 1000 # Test that both states are explored (basic functionality test) - b_samples = chn[:b] + b_samples = chn[@varname(b)] unique_b_values = unique(skipmissing(b_samples)) @test length(unique_b_values) >= 1 # At least one value should be sampled # Test that θ[1] values are reasonable when they exist - theta1_samples = collect(skipmissing(chn[:, Symbol("θ[1]"), 1])) - if length(theta1_samples) > 0 - @test all(isfinite, theta1_samples) # All samples should be finite - @test std(theta1_samples) > 0.1 # Should show some variation - end + theta1_samples = collect(skipmissing(chn[@varname(θ[1])])) + @test all(isfinite, theta1_samples) # All samples should be finite + @test std(theta1_samples) > 0.1 # Should show some variation # Test that when b=0, only θ[1] exists, and when b=1, both θ[1] and θ[2] exist - theta2_col_exists = Symbol("θ[2]") in names(chn) - if theta2_col_exists - theta2_samples = chn[:, Symbol("θ[2]"), 1] - # θ[2] should have some missing values (when b=0) and some non-missing (when b=1) - n_missing_theta2 = sum(ismissing.(theta2_samples)) - n_present_theta2 = sum(.!ismissing.(theta2_samples)) - - # At least some θ[2] values should be missing (corresponding to b=0 states) - # This is a basic structural test - we're not testing exact analytical results - @test n_missing_theta2 > 0 || n_present_theta2 > 0 # One of these 
should be true - end + theta2_samples = chn[@varname(θ[2])] + # θ[2] should have some missing values (when b=0) and some non-missing (when b=1) + n_missing_theta2 = sum(ismissing.(theta2_samples)) + n_present_theta2 = sum(.!ismissing.(theta2_samples)) + # At least some θ[2] values should be missing (corresponding to b=0 states) + # This is a basic structural test - we're not testing exact analytical results + @test n_missing_theta2 > 0 || n_present_theta2 > 0 # One of these should be true end @testset "Demo model" begin @@ -606,18 +597,27 @@ end thinning=thinning, ) + # Extract varname leaves. + vns = DynamicPPL.TestUtils.varnames(model) + vn_leaves = Set{DynamicPPL.VarName}() + for vn in vns + val = first(chain[vn]) + leaves = AbstractPPL.varname_leaves(vn, val) + vn_leaves = union(vn_leaves, leaves) + end + # Perform KS test to ensure that the chains are similar. - xs = Array(chain) - xs_true = Array(chain_true) - for i in 1:size(xs, 2) - @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) + for vn in vn_leaves + vals = vec(chain[vn]) + true_vals = vec(chain_true[vn]) + @test two_sample_test(vals, true_vals; warn_on_fail=true) # Let's make sure that the significance level is not too low by # checking that the KS test fails for some simple transformations. # TODO: Replace the heuristic below with closed-form implementations # of the targets, once they are implemented in DynamicPPL. 
- @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) + @test !two_sample_test(0.9 .* true_vals, true_vals) + @test !two_sample_test(1.1 .* true_vals, true_vals) + @test !two_sample_test(1e-1 .+ true_vals, true_vals) end end end @@ -671,9 +671,9 @@ end end model = model1() spl = Gibbs(@varname(x[1]) => HMC(0.5, 10), @varname(y.a) => MH()) - @test sample(model, spl, 10) isa MCMCChains.Chains + @test sample(model, spl, 10) isa VNChain spl = Gibbs((@varname(x[1]), @varname(y.a)) => HMC(0.5, 10)) - @test sample(model, spl, 10) isa MCMCChains.Chains + @test sample(model, spl, 10) isa VNChain end @testset "submodels" begin @@ -687,9 +687,9 @@ end spl = Gibbs( @varname(a.x) => HMC(0.5, 10), @varname(b.x) => MH(), @varname(x) => MH() ) - @test sample(model, spl, 10) isa MCMCChains.Chains + @test sample(model, spl, 10) isa VNChain spl = Gibbs((@varname(a.x), @varname(b.x), @varname(x)) => MH()) - @test sample(model, spl, 10) isa MCMCChains.Chains + @test sample(model, spl, 10) isa VNChain end @testset "CSMC + ESS" begin @@ -746,7 +746,7 @@ end chn = sample( logp_check(), Gibbs(@varname(x) => sampler), 100; progress=false ) - @test isapprox(logpdf.(Normal(), chn[:x]), chn[:logjoint]) + @test isapprox(logpdf.(Normal(), chn[@varname(x)]), chn[:logjoint]) end end diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index 22b04c30ba..909080722b 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -51,21 +51,24 @@ using Turing ) chain = sample(StableRNG(23), model, sampler, 1_000) @test size(chain, 1) == 1_000 - @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 - @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 + @test mean(chain[@varname(precision)]) ≈ mean(reference_chain[@varname(precision)]) atol = + 0.1 + @test mean(chain[@varname(m)]) ≈ 
mean(reference_chain[@varname(m)]) atol = 0.1 # Mix GibbsConditional with an MCMC sampler sampler = Gibbs(:precision => GibbsConditional(cond_precision), :m => MH()) chain = sample(StableRNG(23), model, sampler, 1_000) @test size(chain, 1) == 1_000 - @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 - @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 + @test mean(chain[@varname(precision)]) ≈ mean(reference_chain[@varname(precision)]) atol = + 0.1 + @test mean(chain[@varname(m)]) ≈ mean(reference_chain[@varname(m)]) atol = 0.1 sampler = Gibbs(:m => GibbsConditional(cond_m), :precision => HMC(0.1, 10)) chain = sample(StableRNG(23), model, sampler, 1_000) @test size(chain, 1) == 1_000 - @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 - @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 + @test mean(chain[@varname(precision)]) ≈ mean(reference_chain[@varname(precision)]) atol = + 0.1 + @test mean(chain[@varname(m)]) ≈ mean(reference_chain[@varname(m)]) atol = 0.1 # Block sample, sampling the same variable with multiple component samplers. sampler = Gibbs( @@ -80,8 +83,9 @@ using Turing ) chain = sample(StableRNG(23), model, sampler, 1_000) @test size(chain, 1) == 1_000 - @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 - @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 + @test mean(chain[@varname(precision)]) ≈ mean(reference_chain[@varname(precision)]) atol = + 0.1 + @test mean(chain[@varname(m)]) ≈ mean(reference_chain[@varname(m)]) atol = 0.1 end @testset "Simple normal model" begin @@ -114,7 +118,7 @@ using Turing chain = sample(StableRNG(23), model, sampler, 1_000) # The correct posterior mean isn't true_mean, but it is very close, because we # have a lot of data. 
- @test mean(chain, :mean) ≈ true_mean atol = 0.05 + @test mean(chain[@varname(mean)]) ≈ true_mean atol = 0.05 end @testset "Double simple normal" begin @@ -165,8 +169,8 @@ using Turing chain = sample(StableRNG(23), model, sampler, 1_000) # The correct posterior mean isn't true_mean, but it is very close, because we # have a lot of data. - @test mean(chain, :mean1) ≈ true_mean1 atol = 0.1 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 + @test mean(chain[@varname(mean1)]) ≈ true_mean1 atol = 0.1 + @test mean(chain[@varname(mean2)]) ≈ true_mean2 atol = 0.1 # Test using GibbsConditional for both in a block, returning a Dict. function cond_mean_dict(c) @@ -182,8 +186,8 @@ using Turing (:var1, :var2) => HMC(0.1, 10), ) chain = sample(StableRNG(23), model, sampler, 1_000) - @test mean(chain, :mean1) ≈ true_mean1 atol = 0.1 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 + @test mean(chain[@varname(mean1)]) ≈ true_mean1 atol = 0.1 + @test mean(chain[@varname(mean2)]) ≈ true_mean2 atol = 0.1 # As above but with a NamedTuple rather than a Dict. function cond_mean_nt(c) @@ -197,8 +201,8 @@ using Turing (:var1, :var2) => HMC(0.1, 10), ) chain = sample(StableRNG(23), model, sampler, 1_000) - @test mean(chain, :mean1) ≈ true_mean1 atol = 0.1 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 + @test mean(chain[@varname(mean1)]) ≈ true_mean1 atol = 0.1 + @test mean(chain[@varname(mean2)]) ≈ true_mean2 atol = 0.1 end # Test simultaneously conditioning and fixing variables. @@ -219,14 +223,14 @@ using Turing :var2 => HMC(0.1, 10), ) chain = sample(StableRNG(23), model_condition_fix, sampler, 10_000) - @test mean(chain, :mean1) ≈ 0.0 atol = 0.1 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 + @test mean(chain[@varname(mean1)]) ≈ 0.0 atol = 0.1 + @test mean(chain[@varname(mean2)]) ≈ true_mean2 atol = 0.1 # As above, but reverse the order of condition and fix. 
model_fix_condition = fix(condition(base_model; x2=x2_obs); x1=x1_obs) chain = sample(StableRNG(23), model_condition_fix, sampler, 10_000) - @test mean(chain, :mean1) ≈ 0.0 atol = 0.1 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 + @test mean(chain[@varname(mean1)]) ≈ 0.0 atol = 0.1 + @test mean(chain[@varname(mean2)]) ≈ true_mean2 atol = 0.1 end end @@ -266,9 +270,9 @@ using Turing (@varname(a[1]), @varname(a[2]), @varname(a[3])) => ESS(), ) chain = sample(StableRNG(23), m, sampler, 10_000) - @test mean(chain, Symbol("b[1]")) ≈ 0.0 atol = 0.05 - @test mean(chain, Symbol("b[2]")) ≈ 10.0 atol = 0.05 - @test mean(chain, Symbol("b[3]")) ≈ 20.0 atol = 0.05 + @test mean(chain[@varname(b[1])]) ≈ 0.0 atol = 0.05 + @test mean(chain[@varname(b[2])]) ≈ 10.0 atol = 0.05 + @test mean(chain[@varname(b[3])]) ≈ 20.0 atol = 0.05 condvals = @vnt begin @template a = zeros(3) @@ -285,9 +289,9 @@ using Turing @varname(a[3]) => ESS(), ) chain = sample(StableRNG(23), m_condfix, sampler, 10_000) - @test mean(chain, Symbol("b[1]")) ≈ 100.0 atol = 0.05 - @test mean(chain, Symbol("b[2]")) ≈ 200.0 atol = 0.05 - @test mean(chain, Symbol("b[3]")) ≈ 20.0 atol = 0.05 + @test mean(chain[@varname(b[1])]) ≈ 100.0 atol = 0.05 + @test mean(chain[@varname(b[2])]) ≈ 200.0 atol = 0.05 + @test mean(chain[@varname(b[3])]) ≈ 20.0 atol = 0.05 end @testset "Helpful error outside Gibbs" begin diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 5bba98bc04..83db1dabbb 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -5,6 +5,7 @@ using ..NumericalTests: check_gdemo, check_numerical using Bijectors: Bijectors using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample using DynamicPPL: DynamicPPL +using FlexiChains: FlexiChains import ForwardDiff using HypothesisTests: ApproximateTwoSampleKSTest, pvalue import ReverseDiff @@ -54,7 +55,9 @@ using Turing chain = sample(StableRNG(seed), constrained_simplex_test(obs12), HMC(0.75, 2), 1000) - check_numerical(chain, ["ps[1]", 
"ps[2]"], [5 / 16, 11 / 16]; atol=0.015) + check_numerical( + chain, [@varname(ps[1]), @varname(ps[2])], [5 / 16, 11 / 16]; atol=0.015 + ) end # Test the sampling of a matrix-value distribution. @@ -65,12 +68,7 @@ using Turing n_samples = 1_000 chain = sample(StableRNG(24), model_f, HMC(0.15, 7), n_samples) - # Reshape the chain into an array of 2x2 matrices, one per sample. Then compute - # the average of the samples, as a matrix - r = reshape(Array(chain), n_samples, 2, 2) - r_mean = dropdims(mean(r; dims=1); dims=1) - - @test isapprox(r_mean, mean(dist); atol=0.2) + @test isapprox(mean(chain[@varname(v)]), mean(dist); atol=0.2) end @testset "multivariate support" begin @@ -156,9 +154,9 @@ using Turing alg1 = Gibbs(:m => PG(10), :s => NUTS(100, 0.65)) alg2 = Gibbs(:m => PG(10), :s => HMC(0.1, 3)) alg3 = Gibbs(:m => PG(10), :s => HMCDA(100, 0.65, 0.3)) - @test sample(StableRNG(seed), gdemo_default, alg1, 10) isa Chains - @test sample(StableRNG(seed), gdemo_default, alg2, 10) isa Chains - @test sample(StableRNG(seed), gdemo_default, alg3, 10) isa Chains + @test sample(StableRNG(seed), gdemo_default, alg1, 10) isa VNChain + @test sample(StableRNG(seed), gdemo_default, alg2, 10) isa VNChain + @test sample(StableRNG(seed), gdemo_default, alg3, 10) isa VNChain end # issue #1923 @@ -167,7 +165,8 @@ using Turing res1 = sample(StableRNG(seed), gdemo_default, alg, 10) res2 = sample(StableRNG(seed), gdemo_default, alg, 10) res3 = sample(StableRNG(seed), gdemo_default, alg, 10) - @test Array(res1) == Array(res2) == Array(res3) + @test FlexiChains.has_same_data(res1, res2) + @test FlexiChains.has_same_data(res1, res3) end @testset "initial params are respected" begin @@ -235,7 +234,7 @@ using Turing 10; nadapts=0, discard_adapt=false, - initial_state=loadstate(chn1), + initial_state=only(loadstate(chn1)), ) # if chn2 uses initial_state, its first sample should be somewhere around 5. 
if # initial_state isn't used, it will be sampled from [-2, 2] so this test should fail @@ -250,7 +249,8 @@ using Turing end model = vector_of_dirichlet() chain = sample(model, NUTS(), 1_000) - @test mean(Array(chain)) ≈ 0.2 + xs = vcat(stack(chain[@varname(xs[1])]), stack(chain[@varname(xs[2])])) + @test mean(xs) ≈ 0.2 end @testset "issue: #2195" begin diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index fecd5804ff..f9bbc09544 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -222,8 +222,8 @@ GKernel(variance, vn) = (vnt -> Normal(vnt[vn], sqrt(variance))) chn_big = sample(StableRNG(seed), mod, alg_big, 1_000) # Test that the small variance version is actually smaller. - variance_small = var(diff(Array(chn_small["μ[1]"]); dims=1)) - variance_big = var(diff(Array(chn_big["μ[1]"]); dims=1)) + variance_small = var(diff(chn_small[@varname(μ[1])]; dims=1)) + variance_big = var(diff(chn_big[@varname(μ[1])]; dims=1)) @test variance_small < variance_big / 100.0 end @@ -237,12 +237,12 @@ GKernel(variance, vn) = (vnt -> Normal(vnt[vn], sqrt(variance))) chain = sample(StableRNG(seed), test(1), MH(), 5_000) for i in 1:5 - @test mean(chain, "T[1][$i]") ≈ 0.2 atol = 0.01 + @test mean(chain[@varname(T[1][i])]) ≈ 0.2 atol = 0.01 end chain = sample(StableRNG(seed), test(10), MH(), 5_000) for j in 1:10, i in 1:5 - @test mean(chain, "T[$j][$i]") ≈ 0.2 atol = 0.01 + @test mean(chain[@varname(T[j][i])]) ≈ 0.2 atol = 0.01 end end @@ -252,11 +252,12 @@ GKernel(variance, vn) = (vnt -> Normal(vnt[vn], sqrt(variance))) chain = sample(StableRNG(seed), f(), MH(), 5_000) indices = [(1, 1), (2, 1), (2, 2)] values = [1, 0, 0.785] + uplo_sym = uplo == 'U' ? 
:U : :L for ((i, j), v) in zip(indices, values) if uplo == 'U' # Transpose - @test mean(chain, "x.$uplo[$j, $i]") ≈ v atol = 0.01 + @test mean(chain[@varname(x.$uplo_sym[j, i])]) ≈ v atol = 0.01 else - @test mean(chain, "x.$uplo[$i, $j]") ≈ v atol = 0.01 + @test mean(chain[@varname(x.$uplo_sym[i, j])]) ≈ v atol = 0.01 end end end diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index 8b3c68ff15..19e6655355 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -4,6 +4,7 @@ using ..Models: gdemo_default using ..SamplerTestUtils: test_chain_logp_metadata using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial using Distributions: Bernoulli, Beta, Gamma, Normal, sample +using FlexiChains: VNChain using Random: Random using StableRNGs: StableRNG using Test: @test, @test_logs, @test_throws, @testset @@ -102,19 +103,19 @@ using Turing @test_logs (:warn, r"ignored") sample(normal(), SMC(), 10; discard_initial=5) chn = sample(normal(), SMC(), 10; discard_initial=5) @test size(chn, 1) == 10 - @test chn isa MCMCChains.Chains + @test chn isa VNChain @test_logs (:warn, r"ignored") sample(normal(), SMC(), 10; thinning=3) chn2 = sample(normal(), SMC(), 10; thinning=3) @test size(chn2, 1) == 10 - @test chn2 isa MCMCChains.Chains + @test chn2 isa VNChain @test_logs (:warn, r"ignored") sample( normal(), SMC(), 10; discard_initial=2, thinning=2 ) chn3 = sample(normal(), SMC(), 10; discard_initial=2, thinning=2) @test size(chn3, 1) == 10 - @test chn3 isa MCMCChains.Chains + @test chn3 isa VNChain end end @@ -193,11 +194,11 @@ end end chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000) - @test chain isa MCMCChains.Chains + @test chain isa VNChain @test mean(chain[:x]) ≈ 2.5 atol = 0.3 chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000) - @test chain2 isa MCMCChains.Chains + @test chain2 isa VNChain @test mean(chain2[:x]) ≈ 7.5 atol = 0.3 end diff --git a/test/mcmc/repeat_sampler.jl 
b/test/mcmc/repeat_sampler.jl index 1a22884029..a88a9aeee5 100644 --- a/test/mcmc/repeat_sampler.jl +++ b/test/mcmc/repeat_sampler.jl @@ -1,7 +1,7 @@ module RepeatSamplerTests +import FlexiChains using ..Models: gdemo_default -using MCMCChains: MCMCChains using Random: Xoshiro using Test: @test, @testset using Turing @@ -35,10 +35,9 @@ using Turing num_samples, num_chains, ) - # isequal to avoid comparing `missing`s in chain stats - @test chn1 isa MCMCChains.Chains - @test chn2 isa MCMCChains.Chains - @test isequal(chn1.value, chn2.value) + @test chn1 isa VNChain + @test chn2 isa VNChain + @test FlexiChains.has_same_data(chn1, chn2) end end diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index c4f1e48c06..63a7c5451e 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -44,9 +44,9 @@ end check_gdemo(chain; atol=0.25) # Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh) - v = get(chain, [:SGLD_stepsize, :s, :m]) - s_weighted = dot(v.SGLD_stepsize, v.s) / sum(v.SGLD_stepsize) - m_weighted = dot(v.SGLD_stepsize, v.m) / sum(v.SGLD_stepsize) + ss = chain[:SGLD_stepsize] + s_weighted = dot(ss, chain[@varname(s)]) / sum(ss) + m_weighted = dot(ss, chain[@varname(m)]) / sum(ss) @test s_weighted ≈ 49 / 24 atol = 0.2 @test m_weighted ≈ 7 / 6 atol = 0.2 end diff --git a/test/mcmc/utilities.jl b/test/mcmc/utilities.jl index 4ed3ac30e4..c04bbadf8a 100644 --- a/test/mcmc/utilities.jl +++ b/test/mcmc/utilities.jl @@ -1,15 +1,14 @@ module MCMCUtilitiesTests using ..Models: gdemo_default +using FlexiChains: FlexiChains using Test: @test, @testset using Turing @testset "Timer" begin chain = sample(gdemo_default, MH(), 1000) - - @test chain.info.start_time isa Float64 - @test chain.info.stop_time isa Float64 - @test chain.info.start_time ≤ chain.info.stop_time + @test FlexiChains.sampling_time(chain) isa Vector{Float64} + @test only(FlexiChains.sampling_time(chain)) > 0.0 end end diff --git a/test/stdlib/RandomMeasures.jl 
b/test/stdlib/RandomMeasures.jl index 63c6e42422..8de9a6a42c 100644 --- a/test/stdlib/RandomMeasures.jl +++ b/test/stdlib/RandomMeasures.jl @@ -59,8 +59,7 @@ using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess model_fun = infiniteGMM(data) chain = sample(model_fun, SMC(), iterations) - @test chain isa MCMCChains.Chains - @test eltype(chain.value) === Union{Float64,Missing} + @test chain isa VNChain end # partitions = [ # [[1, 2, 3, 4]], diff --git a/test/stdlib/distributions.jl b/test/stdlib/distributions.jl index 56c2e59b13..334bbfde14 100644 --- a/test/stdlib/distributions.jl +++ b/test/stdlib/distributions.jl @@ -1,6 +1,5 @@ module DistributionsTests -using ..NumericalTests: check_dist_numerical using Distributions using LinearAlgebra: I using Random: Random @@ -9,11 +8,37 @@ using StatsFuns: logistic using Test: @testset, @test using Turing +function check_dist_numerical( + dist, chn; mean_atol=0.1, mean_rtol=0.1, var_atol=1.0, var_rtol=0.5 +) + @testset "numerical" begin + # Extract values. + chn_xs = chn[@varname(x)] + + # Check means. + dist_mean = mean(dist) + if !all(isnan, dist_mean) && !all(isinf, dist_mean) + chn_mean = mean(chn_xs) + @test chn_mean ≈ dist_mean atol = mean_atol rtol = mean_rtol + end + + # Check variances. 
+ # var() for Distributions.MatrixDistribution is not defined + if !(dist isa Distributions.MatrixDistribution) + # Variance + dist_var = var(dist) + if !all(isnan, dist_var) && !all(isinf, dist_var) + chn_var = var(chn_xs) + @test chn_var ≈ dist_var atol = var_atol rtol = var_rtol + end + end + end +end + @testset "distributions.jl" begin - rng = StableRNG(12345) @testset "distributions functions" begin ns = 10 - logitp = randn(rng) + logitp = randn() d1 = BinomialLogit(ns, logitp) d2 = Binomial(ns, logistic(logitp)) k = 3 @@ -24,7 +49,7 @@ using Turing d = OrderedLogistic(-2, [-1, 1]) n = 1_000_000 - y = rand(rng, d, n) + y = rand(d, n) K = length(d.cutpoints) + 1 p = [mean(==(k), y) for k in 1:K] # empirical probs pmf = [exp(logpdf(d, k)) for k in 1:K] @@ -52,9 +77,10 @@ using Turing @testset "single distribution correctness" begin n_samples = 10_000 - mean_tol = 0.1 + mean_atol = 0.1 + mean_rtol = 0.1 var_atol = 1.0 - var_tol = 0.5 + var_rtol = 0.5 multi_dim = 4 # 1. UnivariateDistribution # NOTE: Noncentral distributions are commented out because of @@ -130,22 +156,22 @@ using Turing @model m() = x ~ dist - seed = if dist isa GeneralizedExtremeValue - # GEV is prone to giving really wacky results that are quite - # seed-dependent. - StableRNG(469) - else - StableRNG(468) - end - chn = sample(seed, m(), HMC(0.05, 20), n_samples) + # Note: GeneralizedExtremeValue is prone to giving really wacky + # results that are quite seed-dependent. Do not hesitate to + # change the seed if the test fails, even by a large margin. + seed = StableRNG(468) + chn = sample( + seed, m(), HMC(0.05, 20), n_samples; progress=false + ) # Numerical tests. 
check_dist_numerical( dist, chn; - mean_tol=mean_tol, + mean_atol=mean_atol, + mean_rtol=mean_rtol, var_atol=var_atol, - var_tol=var_tol, + var_rtol=var_rtol, ) end end diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl index 97d1740141..1322c5362f 100644 --- a/test/test_utils/numerical_tests.jl +++ b/test/test_utils/numerical_tests.jl @@ -1,71 +1,41 @@ module NumericalTests using Distributions -using MCMCChains: namesingroup using Test: @test, @testset +using Turing: @varname using HypothesisTests: HypothesisTests export check_MoGtest_default, check_MoGtest_default_z_vector, check_dist_numerical, check_gdemo, check_numerical -function check_dist_numerical(dist, chn; mean_tol=0.1, var_atol=1.0, var_tol=0.5) - @testset "numerical" begin - # Extract values. - chn_xs = Array(chn[1:2:end, namesingroup(chn, :x), :]) - - # Check means. - dist_mean = mean(dist) - mean_shape = size(dist_mean) - if !all(isnan, dist_mean) && !all(isinf, dist_mean) - chn_mean = vec(mean(chn_xs; dims=1)) - chn_mean = length(chn_mean) == 1 ? chn_mean[1] : reshape(chn_mean, mean_shape) - atol_m = if length(chn_mean) > 1 - mean_tol * length(chn_mean) - else - max(mean_tol, mean_tol * chn_mean) - end - @test chn_mean ≈ dist_mean atol = atol_m - end - - # Check variances. - # var() for Distributions.MatrixDistribution is not defined - if !(dist isa Distributions.MatrixDistribution) - # Variance - dist_var = var(dist) - var_shape = size(dist_var) - if !all(isnan, dist_var) && !all(isinf, dist_var) - chn_var = vec(var(chn_xs; dims=1)) - chn_var = length(chn_var) == 1 ? 
chn_var[1] : reshape(chn_var, var_shape) - atol_v = if length(chn_mean) > 1 - mean_tol * length(chn_mean) - else - max(mean_tol, mean_tol * chn_mean) - end - @test chn_mean ≈ dist_mean atol = atol_v - end - end - end -end - # Helper function for numerical tests -function check_numerical(chain, symbols::Vector, exact_vals::Vector; atol=0.2, rtol=0.0) - for (sym, val) in zip(symbols, exact_vals) - E = val isa Real ? mean(chain[sym]) : vec(mean(chain[sym]; dims=1)) - @info (symbol=sym, exact=val, evaluated=E) +function check_numerical(chain, varnames::Vector, exact_vals::Vector; atol=0.2, rtol=0.0) + for (vn, val) in zip(varnames, exact_vals) + E = val isa Real ? mean(chain[vn]) : vec(mean(chain[vn]; dims=1)) + @info (varname=vn, exact=val, evaluated=E) @test E ≈ val atol = atol rtol = rtol end end # Wrapper function to quickly check gdemo accuracy. function check_gdemo(chain; atol=0.2, rtol=0.0) - return check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=atol, rtol=rtol) + return check_numerical( + chain, [@varname(s), @varname(m)], [49 / 24, 7 / 6]; atol=atol, rtol=rtol + ) end # Wrapper function to check MoGtest. 
function check_MoGtest_default(chain; atol=0.2, rtol=0.0) return check_numerical( chain, - [:z1, :z2, :z3, :z4, :mu1, :mu2], + [ + @varname(z1), + @varname(z2), + @varname(z3), + @varname(z4), + @varname(mu1), + @varname(mu2) + ], [1.0, 1.0, 2.0, 2.0, 1.0, 4.0]; atol=atol, rtol=rtol, @@ -75,7 +45,14 @@ end function check_MoGtest_default_z_vector(chain; atol=0.2, rtol=0.0) return check_numerical( chain, - [Symbol("z[1]"), Symbol("z[2]"), Symbol("z[3]"), Symbol("z[4]"), :mu1, :mu2], + [ + @varname(z[1]), + @varname(z[2]), + @varname(z[3]), + @varname(z[4]), + @varname(mu1), + @varname(mu2) + ], [1.0, 1.0, 2.0, 2.0, 1.0, 4.0]; atol=atol, rtol=rtol, diff --git a/test/test_utils/sampler.jl b/test/test_utils/sampler.jl index f2a15047e1..70f7bacd51 100644 --- a/test/test_utils/sampler.jl +++ b/test/test_utils/sampler.jl @@ -21,8 +21,8 @@ function test_chain_logp_metadata(spl) end chn = sample(f(), spl, 100) # Check that the log-prior term is calculated in unlinked space. - @test chn[:logprior] ≈ logpdf.(LogNormal(), chn[:x]) - @test chn[:loglikelihood] ≈ logpdf.(Normal.(chn[:x]), 1.0) + @test chn[:logprior] ≈ logpdf.(LogNormal(), chn[@varname(x)]) + @test chn[:loglikelihood] ≈ logpdf.(Normal.(chn[@varname(x)]), 1.0) # This should always be true, but it also indirectly checks that the # log-joint is also calculated in unlinked space. 
@test chn[:logjoint] ≈ chn[:logprior] + chn[:loglikelihood] @@ -95,7 +95,7 @@ function test_sampler_analytical( for vn_leaf in AbstractPPL.varname_leaves(vn, AbstractPPL.getvalue(target_values, vn)) target_value = AbstractPPL.getvalue(target_values, vn_leaf) - chain_mean_value = mean(chain[Symbol(vn_leaf)]) + chain_mean_value = mean(chain[vn_leaf]) @test chain_mean_value ≈ target_value atol = atol rtol = rtol end end diff --git a/test/variational/vi.jl b/test/variational/vi.jl index fff1d0118b..069c076c44 100644 --- a/test/variational/vi.jl +++ b/test/variational/vi.jl @@ -1,4 +1,3 @@ - module AdvancedVITests using ..Models: gdemo_default @@ -6,17 +5,16 @@ using ..NumericalTests: check_gdemo using AdvancedVI using Bijectors: Bijectors -using Distributions: Dirichlet, Normal +using FlexiChains: FlexiChain, Parameter using LinearAlgebra -using MCMCChains: Chains using Random -using ReverseDiff +import ReverseDiff using StableRNGs: StableRNG using Test: @test, @testset, @test_throws using Turing using Turing.Variational -begin +@testset verbose = true "variational/vi.jl" begin adtype = AutoReverseDiff() operator = AdvancedVI.ClipScale() @@ -103,8 +101,12 @@ begin ) N = 1000 - samples = transpose(rand(rng, q, N)) - chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"]) + # 2 * N matrix + samples = rand(rng, q, N) + samples_dict = Dict( + Parameter(@varname(s)) => samples[1, :], Parameter(@varname(m)) => samples[2, :] + ) + chn = VNChain(N, 1, samples_dict) check_gdemo(chn; atol=0.5) end