Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ jobs:
tests/distributions/test_continuous.py
tests/distributions/test_multivariate.py
tests/distributions/moments/test_means.py
tests/distributions/moments/test_variances.py

- |
tests/distributions/test_censored.py
Expand Down
3 changes: 2 additions & 1 deletion pymc/distributions/moments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
"""Moments dispatchers for pymc random variables."""

from pymc.distributions.moments.means import mean
from pymc.distributions.moments.variances import variance

__all__ = ["mean"]
__all__ = ["mean", "variance"]
24 changes: 15 additions & 9 deletions pymc/distributions/moments/means.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@
MatrixNormalRV,
MvStudentTRV,
StickBreakingWeightsRV,
WishartRV,
_LKJCholeskyCovRV,
)
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.distributions.shape_utils import maybe_resize, rv_size_is_none
from pymc.exceptions import UndefinedMomentException

__all__ = ["mean"]
Expand All @@ -100,12 +101,6 @@ def mean(rv: TensorVariable) -> TensorVariable:
return _mean(rv.owner.op, rv, *rv.owner.inputs)


def maybe_resize(a: TensorVariable, size) -> TensorVariable:
if not rv_size_is_none(size):
a = pt.full(size, a)
return a


@_mean.register(AsymmetricLaplaceRV)
def asymmetric_laplace_mean(op, rv, rng, size, b, kappa, mu):
return maybe_resize(mu - (kappa - 1 / kappa) / b, size)
Expand Down Expand Up @@ -250,7 +245,7 @@ def invgamma_mean(op, rv, rng, size, alpha, beta):


@_mean.register(KroneckerNormalRV)
def kronecker_normal_mean(op, rv, rng, size, mu, covs, chols, evds):
def kronecker_normal_mean(op, rv, rng, size, mu, sigma, *covs):
mean = mu
if not rv_size_is_none(size):
mean_size = pt.concatenate([size, mu.shape])
Expand Down Expand Up @@ -376,7 +371,9 @@ def polya_gamma_mean(op, rv, rng, size, h, z):


@_mean.register(RiceRV)
def rice_mean(op, rv, rng, size, nu, sigma):
def rice_mean(op, rv, rng, size, b, sigma):
# b is the shape parameter, nu = b * sigma is the noncentrality parameter
nu = b * sigma
nu_sigma_ratio = -(nu**2) / (2 * sigma**2)
return maybe_resize(
sigma
Expand Down Expand Up @@ -454,3 +451,12 @@ def wald_mean(op, rv, rng, size, mu, lam, alpha):
@_mean.register(WeibullBetaRV)
def weibull_mean(op, rv, rng, size, alpha, beta):
return maybe_resize(beta * pt.gamma(1 + 1 / alpha), size)


@_mean.register(WishartRV)
def wishart_mean(op, rv, rng, size, nu, V):
mean = nu * V
if not rv_size_is_none(size):
mean_size = pt.concatenate([size, V.shape[-2:]])
mean = pt.full(mean_size, mean)
return mean
Loading