@@ -34,10 +34,11 @@ class MixedHMC(DiscreteHMCGibbs):
3434 :param int num_discrete_updates: Number of times to update discrete variables.
3535 Defaults to the number of discrete latent variables.
3636 :param bool random_walk: If False, Gibbs sampling will be used to draw a sample from the
37- conditional `p(gibbs_site | remaining sites)`. Otherwise, a sample will be drawn uniformly
37+ conditional `p(gibbs_site | remaining sites)`, where `gibbs_site` is one of the
38+ discrete sample sites in the model. Otherwise, a sample will be drawn uniformly
3839 from the domain of `gibbs_site`. Defaults to False.
3940 :param bool modified: whether to use a modified proposal, as suggested in reference [2], which
40- always proposes a new state for the current Gibbs site. Defaults to True.
41+ always proposes a new state for the current Gibbs site (i.e. discrete site) . Defaults to True.
4142 The modified scheme appears in the literature under the name "modified Gibbs sampler" or
4243 "Metropolised Gibbs sampler".
4344
@@ -61,9 +62,10 @@ class MixedHMC(DiscreteHMCGibbs):
6162 >>> mcmc = MCMC(kernel, 1000, 100000, progress_bar=False)
6263 >>> mcmc.run(random.PRNGKey(0), probs, locs)
6364 >>> mcmc.print_summary() # doctest: +SKIP
64- >>> samples = mcmc.get_samples()["x"]
65- >>> assert abs(jnp.mean(samples) - 1.3) < 0.1
66- >>> assert abs(jnp.var(samples) - 4.36) < 0.5
65+ >>> samples = mcmc.get_samples()
66+ >>> assert "x" in samples and "c" in samples
67+ >>> assert abs(jnp.mean(samples["x"]) - 1.3) < 0.1
68+ >>> assert abs(jnp.var(samples["x"]) - 4.36) < 0.5
6769 """
6870
6971 def __init__ (self , inner_kernel , * , num_discrete_updates = None , random_walk = False , modified = False ):
@@ -147,7 +149,7 @@ def body_fn(i, vals):
147149 arrival_times = ops .index_update (arrival_times , idx , 1. )
148150
149151 # this is a trick, so that in a sub-trajectory of HMC, we always accept the new proposal
150- pe = jnp .inf if self . inner_kernel . _algo == "HMC" else hmc_state . potential_energy
152+ pe = jnp .inf
151153 hmc_state = hmc_state ._replace (trajectory_length = trajectory_length , potential_energy = pe )
152154 # Algo 1, line 7: perform a sub-trajectory
153155 hmc_state = update_continuous (hmc_state , z_discrete )
0 commit comments