Skip to content

Commit dfb3f78

Browse files
Fix kronecker_normal signature, fix rice mean, add wishhart mean
1 parent 471528d commit dfb3f78

2 files changed

Lines changed: 35 additions & 6 deletions

File tree

pymc/distributions/moments/means.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,10 @@
7878
MatrixNormalRV,
7979
MvStudentTRV,
8080
StickBreakingWeightsRV,
81+
WishartRV,
8182
_LKJCholeskyCovRV,
8283
)
83-
from pymc.distributions.shape_utils import rv_size_is_none, maybe_resize
84+
from pymc.distributions.shape_utils import maybe_resize, rv_size_is_none
8485
from pymc.exceptions import UndefinedMomentException
8586

8687
__all__ = ["mean"]
@@ -244,7 +245,7 @@ def invgamma_mean(op, rv, rng, size, alpha, beta):
244245

245246

246247
@_mean.register(KroneckerNormalRV)
247-
def kronecker_normal_mean(op, rv, rng, size, mu, covs, chols, evds):
248+
def kronecker_normal_mean(op, rv, rng, size, mu, sigma, *covs):
248249
mean = mu
249250
if not rv_size_is_none(size):
250251
mean_size = pt.concatenate([size, mu.shape])
@@ -370,7 +371,9 @@ def polya_gamma_mean(op, rv, rng, size, h, z):
370371

371372

372373
@_mean.register(RiceRV)
373-
def rice_mean(op, rv, rng, size, nu, sigma):
374+
def rice_mean(op, rv, rng, size, b, sigma):
375+
# b is the shape parameter, nu = b * sigma is the noncentrality parameter
376+
nu = b * sigma
374377
nu_sigma_ratio = -(nu**2) / (2 * sigma**2)
375378
return maybe_resize(
376379
sigma
@@ -448,3 +451,12 @@ def wald_mean(op, rv, rng, size, mu, lam, alpha):
448451
@_mean.register(WeibullBetaRV)
449452
def weibull_mean(op, rv, rng, size, alpha, beta):
450453
return maybe_resize(beta * pt.gamma(1 + 1 / alpha), size)
454+
455+
456+
@_mean.register(WishartRV)
457+
def wishart_mean(op, rv, rng, size, nu, V):
458+
mean = nu * V
459+
if not rv_size_is_none(size):
460+
mean_size = pt.concatenate([size, V.shape[-2:]])
461+
mean = pt.full(mean_size, mean)
462+
return mean

tests/distributions/moments/test_means.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
1517
import numpy as np
1618
import pytest
1719

@@ -53,6 +55,7 @@
5355
uniform,
5456
vonmises,
5557
weibull_min,
58+
wishart,
5659
)
5760

5861
from pymc import (
@@ -168,9 +171,7 @@
168171
[Normal, norm, {"mu": 2, "sigma": 2}, {"loc": 2, "scale": 2}],
169172
[Pareto, pareto, {"alpha": 5, "m": 2}, {"b": 5, "scale": 2}],
170173
[Poisson, poisson, {"mu": 20}, {"mu": 20}],
171-
pytest.param(
172-
Rice, rice, {"b": 2, "sigma": 2}, {"b": 2, "scale": 2}, marks=pytest.mark.xfail
173-
), # Something is wrong with the Rice mean, maybe a Bessel function in pytensor?
174+
[Rice, rice, {"b": 2, "sigma": 2}, {"b": 2, "scale": 2}],
174175
[SkewNormal, skewnorm, {"mu": 2, "sigma": 2, "alpha": 2}, {"loc": 2, "scale": 2, "a": 2}],
175176
[
176177
SkewStudentT,
@@ -262,6 +263,22 @@ def test_mean_equal_expected(dist, dist_params, expected):
262263
)
263264

264265

266+
def test_wishart_mean():
267+
nu = 10
268+
V = np.array([[2.0, 0.5], [0.5, 1.5]])
269+
270+
with warnings.catch_warnings():
271+
warnings.simplefilter("ignore")
272+
from pymc import Wishart
273+
274+
rv = Wishart.dist(nu=nu, V=V)
275+
276+
pymc_mean_val = mean(rv).eval()
277+
scipy_mean = wishart(df=nu, scale=V).mean()
278+
279+
np.testing.assert_almost_equal(pymc_mean_val, scipy_mean)
280+
281+
265282
@pytest.mark.parametrize(
266283
["dist", "dist_params"],
267284
[

0 commit comments

Comments
 (0)