Skip to content

Commit 02a1614

Browse files
committed
address comments
1 parent 3fcdee8 commit 02a1614

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

numpyro/infer/mixed_hmc.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ class MixedHMC(DiscreteHMCGibbs):
3434
:param int num_discrete_updates: Number of times to update discrete variables.
3535
Defaults to the number of discrete latent variables.
3636
:param bool random_walk: If False, Gibbs sampling will be used to draw a sample from the
37-
conditional `p(gibbs_site | remaining sites)`. Otherwise, a sample will be drawn uniformly
37+
conditional `p(gibbs_site | remaining sites)`, where `gibbs_site` is one of the
38+
discrete sample sites in the model. Otherwise, a sample will be drawn uniformly
3839
from the domain of `gibbs_site`. Defaults to False.
3940
:param bool modified: whether to use a modified proposal, as suggested in reference [2], which
40-
always proposes a new state for the current Gibbs site. Defaults to True.
41+
always proposes a new state for the current Gibbs site (i.e. discrete site). Defaults to True.
4142
The modified scheme appears in the literature under the name "modified Gibbs sampler" or
4243
"Metropolised Gibbs sampler".
4344
@@ -61,9 +62,10 @@ class MixedHMC(DiscreteHMCGibbs):
6162
>>> mcmc = MCMC(kernel, 1000, 100000, progress_bar=False)
6263
>>> mcmc.run(random.PRNGKey(0), probs, locs)
6364
>>> mcmc.print_summary() # doctest: +SKIP
64-
>>> samples = mcmc.get_samples()["x"]
65-
>>> assert abs(jnp.mean(samples) - 1.3) < 0.1
66-
>>> assert abs(jnp.var(samples) - 4.36) < 0.5
65+
>>> samples = mcmc.get_samples()
66+
>>> assert "x" in samples and "c" in samples
67+
>>> assert abs(jnp.mean(samples["x"]) - 1.3) < 0.1
68+
>>> assert abs(jnp.var(samples["x"]) - 4.36) < 0.5
6769
"""
6870

6971
def __init__(self, inner_kernel, *, num_discrete_updates=None, random_walk=False, modified=False):
@@ -147,7 +149,7 @@ def body_fn(i, vals):
147149
arrival_times = ops.index_update(arrival_times, idx, 1.)
148150

149151
# this is a trick, so that in a sub-trajectory of HMC, we always accept the new proposal
150-
pe = jnp.inf if self.inner_kernel._algo == "HMC" else hmc_state.potential_energy
152+
pe = jnp.inf
151153
hmc_state = hmc_state._replace(trajectory_length=trajectory_length, potential_energy=pe)
152154
# Algo 1, line 7: perform a sub-trajectory
153155
hmc_state = update_continuous(hmc_state, z_discrete)

0 commit comments

Comments
 (0)