6 changes: 6 additions & 0 deletions docs/source/distributions.rst
@@ -91,6 +91,12 @@ Probability distributions - torch.distributions
.. autoclass:: Laplace
    :members:

:hidden:`MultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~
nit: the underline is too short and will break the docs build. You can run make -C docs html and open docs/build/html/index.html to check the rendered docs.


.. autoclass:: MultivariateNormal
    :members:

:hidden:`Normal`
~~~~~~~~~~~~~~~~~~~~~~~

76 changes: 75 additions & 1 deletion test/test_distributions.py
@@ -33,7 +33,7 @@
from torch.distributions import Distribution
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2,
                                 Dirichlet, Exponential, FisherSnedecor, Gamma, Geometric,
                                 Gumbel, Laplace, Normal, OneHotCategorical, Multinomial, Pareto,
                                 Gumbel, Laplace, MultivariateNormal, Normal, OneHotCategorical, Multinomial, Pareto,
                                 StudentT, Uniform, kl_divergence)
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.constraints import Constraint, is_dependent
@@ -163,6 +163,20 @@ def pairwise(Dist, *params):
            'scale': torch.Tensor([1e-5, 1e-5]),
        },
    ]),
    Example(MultivariateNormal, [
        {
            'loc': Variable(torch.randn(5, 2), requires_grad=True),
            'covariance_matrix': Variable(torch.Tensor([[2.0, 0.3], [0.3, 0.25]]), requires_grad=True),
        },
        {
            'loc': Variable(torch.randn(2), requires_grad=True),
            'scale_tril': Variable(torch.Tensor([[2.0, 0.0], [-0.5, 0.25]]), requires_grad=True),
        },
        {
            'loc': torch.Tensor([1.0, -1.0]),
            'covariance_matrix': torch.Tensor([[5.0, -0.5], [-0.5, 1.5]]),
        },
    ]),
    Example(Normal, [
        {
            'loc': Variable(torch.randn(5, 5), requires_grad=True),
@@ -636,6 +650,66 @@ def test_normal_sample(self):
                                         scipy.stats.norm(loc=loc, scale=scale),
                                         'Normal(mean={}, std={})'.format(loc, scale))

    def test_multivariate_normal_shape(self):
        mean = Variable(torch.randn(5, 3), requires_grad=True)
        mean_no_batch = Variable(torch.randn(3), requires_grad=True)
        mean_multi_batch = Variable(torch.randn(6, 5, 3), requires_grad=True)
        tmp = torch.randn(3, 10)
        cov = Variable(torch.matmul(tmp, tmp.t()) / tmp.shape[-1], requires_grad=True)
        scale_tril = Variable(torch.potrf(cov.data, upper=False), requires_grad=True)

        # ensure that sample, batch, and event shapes are all handled correctly
        self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(), (2, 7, 5, 3))

        # check gradients
nice tests!

        self._gradcheck_log_prob(MultivariateNormal, (mean, cov))
        self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov))
        self._gradcheck_log_prob(MultivariateNormal, (mean, None, scale_tril))

        # check that these also work for tensors, not just variables
        mean = mean.data
        mean_no_batch = mean_no_batch.data
        mean_multi_batch = mean_multi_batch.data
        cov = cov.data
        self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_multivariate_normal_log_prob(self):
        mean = Variable(torch.randn(3), requires_grad=True)
        tmp = torch.randn(3, 10)
        cov = Variable(torch.matmul(tmp, tmp.t()) / tmp.shape[-1], requires_grad=True)
        scale_tril = Variable(torch.potrf(cov.data, upper=False), requires_grad=True)

        # check that log_prob values match scipy's logpdf,
        # and that the covariance_matrix and scale_tril parameterizations are equivalent
        dist1 = MultivariateNormal(mean, cov)
        dist2 = MultivariateNormal(mean, scale_tril=scale_tril)
        ref_dist = scipy.stats.multivariate_normal(mean.data.numpy(), cov.data.numpy())

        x = dist1.sample((10,))
        expected = ref_dist.logpdf(x.data.numpy())

        self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).data.numpy() - expected) ** 2), places=3)
        self.assertAlmostEqual(0.0, np.mean((dist2.log_prob(x).data.numpy() - expected) ** 2), places=3)

    def test_exponential(self):
        rate = Variable(torch.randn(5, 5).abs(), requires_grad=True)
        rate_1d = Variable(torch.randn(1).abs(), requires_grad=True)
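As an aside on the shape semantics the tests above exercise: batch_shape is taken from the leading dimensions of loc, event_shape from its last dimension, and sample() prepends sample_shape. A minimal sketch (hypothetical values, using the API added in this PR):

import torch
from torch.distributions import MultivariateNormal

mean = torch.randn(5, 3)           # batch_shape=(5,), event_shape=(3,)
cov = torch.eye(3)                 # shared across the batch
d = MultivariateNormal(mean, cov)
print(d.sample().size())           # (5, 3) == batch_shape + event_shape
print(d.sample((2, 7)).size())     # (2, 7, 5, 3) == sample_shape + batch_shape + event_shape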
2 changes: 2 additions & 0 deletions torch/distributions/__init__.py
@@ -46,6 +46,7 @@
from .kl import kl_divergence, register_kl
from .laplace import Laplace
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
@@ -68,6 +69,7 @@
    'Gumbel',
    'Laplace',
    'Multinomial',
    'MultivariateNormal',
    'Normal',
    'OneHotCategorical',
    'Pareto',
12 changes: 11 additions & 1 deletion torch/distributions/constraints.py
@@ -14,6 +14,7 @@
    'lower_triangular',
    'nonnegative_integer',
    'positive',
    'positive_definite',
    'positive_integer',
    'real',
    'simplex',
@@ -165,7 +166,15 @@ class _LowerTriangular(Constraint):
    Constrain to lower-triangular square matrices.
    """
    def check(self, value):
        return (torch.tril(value) == value).min(-1)
        return (torch.tril(value) == value).min(-1)[0].min(-1)[0]


class _PositiveDefinite(Constraint):
    """
    Constrain to positive-definite matrices.
    """
    def check(self, value):
        return (torch.symeig(value)[0] > 0.0)


# Public interface.
@@ -183,3 +192,4 @@ def check(self, value):
interval = _Interval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
positive_definite = _PositiveDefinite()
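A quick usage sketch (hypothetical values) of the new positive_definite constraint alongside the existing lower_triangular one:

import torch
from torch.distributions import constraints

A = torch.Tensor([[2.0, 0.3], [0.3, 0.25]])    # symmetric positive-definite
print(constraints.positive_definite.check(A))  # per-eigenvalue check, all ones here
print(constraints.lower_triangular.check(torch.potrf(A, upper=False)))  # ones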
90 changes: 90 additions & 0 deletions torch/distributions/multivariate_normal.py
@@ -0,0 +1,90 @@
import math
from numbers import Number

import torch
from torch.autograd import Variable
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all

class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive-definite covariance matrix :math:`\mathbf{\Sigma}`
    or a lower-triangular matrix :math:`\mathbf{L}` such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`, as obtained e.g. via a
    Cholesky decomposition of the covariance.

    Example:

        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0, 0]` and covariance_matrix=`I`
        -0.2102
        -0.5429
        [torch.FloatTensor of size 2]

    Args:
        loc (Tensor or Variable): mean of the distribution
        covariance_matrix (Tensor or Variable): covariance matrix (must be positive-definite)
        scale_tril (Tensor or Variable): lower-triangular factor of the covariance

    Note:
        Only one of `covariance_matrix` or `scale_tril` can be specified.
    """
    params = {'loc': constraints.real,
@jwvdm noted that we could generically retrieve params for a distribution if we specified a canonical set of parameters. I've been trying to do this by putting only a single canonical parameterization in the .params dict (e.g. either loc/covariance_matrix or loc/scale_tril, but not all three).

But I like what you've done here by adding them all. Maybe we should do that for all distributions and specify canonical_params or something in another field, or just let higher-level libraries like Pyro or ProbTorch do that. WDYT?

              'covariance_matrix': constraints.positive_definite,
              'scale_tril': constraints.lower_triangular}
    support = constraints.real
    has_rsample = True
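As a sketch of the generic retrieval idea discussed above (a hypothetical helper, not part of this PR), a higher-level library could walk the params dict and validate whichever of the overlapping parameterizations the caller actually supplied:

def validate_supplied_params(dist_cls, **kwargs):
    # dist_cls.params maps each parameter name to its constraint
    for name, value in kwargs.items():
        constraint = dist_cls.params[name]
        if not constraint.check(value).all():
            raise ValueError("parameter {} violates constraint {}".format(
                name, type(constraint).__name__))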

    def __init__(self, loc, covariance_matrix=None, scale_tril=None):
        batch_shape, event_shape = loc.shape[:-1], loc.shape[-1:]
        if covariance_matrix is not None and scale_tril is not None:
nit: You could simplify via

if (covariance_matrix is None) == (scale_tril is None):
    raise ValueError(...)

raise ValueError("Either covariance matrix or scale_tril may be specified, not both.")
if covariance_matrix is None and scale_tril is None:
raise ValueError("One of either covariance matrix or scale_tril must be specified")
if scale_tril is None:
Neeraj made this cool decorator called @lazy_property that could allow you to create scale_tril only if it does not exist, which would avoid unnecessary work in some cases. You could use it as follows:

class MultivariateNormal(Distribution):
    def __init__(...):
        ...
        if scale_tril is not None:
            self.scale_tril = scale_tril
            # leave .covariance_matrix unset
        else:
            self.covariance_matrix = covariance_matrix
            # leave .scale_tril unset
    ...
    @lazy_property
    def scale_tril(self):
        return torch.potrf(self.covariance_matrix, upper=False)

    @lazy_property
    def covariance_matrix(self):
        return torch.mm(self.scale_tril, self.scale_tril.t())
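For context, a minimal sketch of how such a decorator can be implemented (assuming the semantics described in the comment; the actual helper may differ):

class lazy_property(object):
    # a non-data descriptor: it computes the wrapped method once on first
    # access, then caches the result as an instance attribute of the same
    # name, which shadows the descriptor on subsequent accesses
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        value = self.wrapped(instance)
        setattr(instance, self.wrapped.__name__, value)
        return value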

            assert covariance_matrix.dim() >= 2
            if covariance_matrix.dim() > 2:
                # TODO support batch_shape for covariance
                raise NotImplementedError("batch_shape for covariance matrix is not yet supported")
            else:
                scale_tril = torch.potrf(covariance_matrix, upper=False)
        else:
            assert scale_tril.dim() >= 2
            if scale_tril.dim() > 2:
                # TODO support batch_shape for scale_tril
                raise NotImplementedError("batch_shape for scale_tril is not yet supported")
            else:
                covariance_matrix = torch.mm(scale_tril, scale_tril.t())
        self.loc = loc
        self.covariance_matrix = covariance_matrix
        self.scale_tril = scale_tril
        super(MultivariateNormal, self).__init__(batch_shape, event_shape)

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        eps = self.loc.new(*shape).normal_()
        return self.loc + torch.matmul(eps, self.scale_tril.t())
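A note on why this reparameterization is correct: with eps drawn i.i.d. standard normal and samples laid out as row vectors,

x = \mu + \epsilon L^\top, \qquad \operatorname{Cov}(x) = L\,\mathbb{E}[\epsilon^\top \epsilon]\,L^\top = L L^\top = \Sigma,

so rsample() produces samples with the requested covariance while keeping gradients flowing through loc and scale_tril.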

    def log_prob(self, value):
        self._validate_log_prob_arg(value)
        delta = value - self.loc
        # TODO replace torch.gesv with an appropriate solver (e.g. potrs)
        M = (delta * torch.gesv(delta.view(-1, delta.shape[-1]).t(),
                                self.covariance_matrix)[0].t().view(delta.shape)).sum(-1)
        log_det = torch.log(self.scale_tril.diag()).sum()
        return -0.5 * (M + self.loc.size(-1) * math.log(2 * math.pi)) - log_det
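For reference, this matches the multivariate normal log-density

\log p(x) = -\tfrac{1}{2}\left[(x-\mu)^\top \Sigma^{-1} (x-\mu) + k \log(2\pi) + \log\lvert\Sigma\rvert\right],

where M is the Mahalanobis term obtained from the gesv solve and, since \Sigma = L L^\top, \log\lvert\Sigma\rvert = 2 \sum_i \log L_{ii}; subtracting log_det = \sum_i \log L_{ii} once therefore accounts for the \tfrac{1}{2}\log\lvert\Sigma\rvert term.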

    def entropy(self):
        # TODO this will need to be modified to support batched covariance
        log_det = self.scale_tril.diag().log().sum(-1, keepdim=True)
        # H = k/2 * (1 + log(2*pi)) + 1/2 * log|Sigma|, where log|Sigma| = 2 * log_det
        H = 0.5 * self.loc.shape[-1] * (1.0 + math.log(2 * math.pi)) + log_det
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
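Finally, a small end-to-end sketch (hypothetical values, using this PR's API) showing that the two parameterizations of the class agree:

import torch
from torch.distributions import MultivariateNormal

loc = torch.Tensor([1.0, -1.0])
cov = torch.Tensor([[5.0, -0.5], [-0.5, 1.5]])
L = torch.potrf(cov, upper=False)   # lower Cholesky factor, cov == L @ L.t()

d_cov = MultivariateNormal(loc, covariance_matrix=cov)
d_tril = MultivariateNormal(loc, scale_tril=L)

x = d_cov.sample((4,))              # shape (4, 2): sample_shape + event_shape
print(d_cov.log_prob(x))            # should match the line below
print(d_tril.log_prob(x))           # up to Cholesky round-off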