Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 311c901

Browse files
Smit-createrlouf
authored andcommitted
Add tests for geometric JAX samples
1 parent 02bb94f commit 311c901

1 file changed

Lines changed: 15 additions & 0 deletions

File tree

tests/link/jax/test_random.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,21 @@ def test_random_bernoulli(size):
369369
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
370370

371371

372+
@pytest.mark.parametrize(
373+
"p, size",
374+
[
375+
(0.6, ()),
376+
(0.2, (4,)),
377+
],
378+
)
379+
def test_random_geometric(p, size):
380+
rng = shared(np.random.RandomState(123))
381+
g = at.random.geometric(p, size=(1000,) + size, rng=rng)
382+
g_fn = function([], g, mode=jax_mode)
383+
samples = g_fn()
384+
np.testing.assert_allclose(samples.mean(), 1 / p, atol=0.1)
385+
386+
372387
def test_random_mvnormal():
373388
rng = shared(np.random.RandomState(123))
374389

0 commit comments

Comments
 (0)