6 changes: 6 additions & 0 deletions docs/source/distributions.rst
@@ -97,6 +97,12 @@ Probability distributions - torch.distributions
.. autoclass:: LogNormal
:members:

:hidden:`MultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MultivariateNormal
:members:

:hidden:`Normal`
~~~~~~~~~~~~~~~~~~~~~~~

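(Illustrative sketch, not part of this diff: what the newly documented class looks like in use. The mean and covariance values are just example numbers, taken from the test parameters below.)

import torch
from torch.autograd import Variable
from torch.distributions import MultivariateNormal

loc = Variable(torch.Tensor([1.0, -1.0]))
cov = Variable(torch.Tensor([[5.0, -0.5], [-0.5, 1.5]]))
m = MultivariateNormal(loc, cov)
x = m.sample((4,))     # 4 draws of dimension 2, shape (4, 2)
print(m.log_prob(x))   # log-densities at the draws, shape (4,)
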
109 changes: 108 additions & 1 deletion test/test_distributions.py
@@ -34,7 +34,7 @@
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
Cauchy, Chi2, Dirichlet, Distribution,
Exponential, FisherSnedecor, Gamma, Geometric,
Gumbel, Laplace, LogNormal, Multinomial,
Gumbel, Laplace, LogNormal, Multinomial, MultivariateNormal,
Normal, OneHotCategorical, Pareto, Poisson,
StudentT, Uniform, constraints, kl_divergence)
from torch.distributions.constraints import Constraint, is_dependent
@@ -221,6 +221,20 @@ def pairwise(Dist, *params):
'alpha': 1.0
}
]),
Example(MultivariateNormal, [
{
'loc': Variable(torch.randn(5, 2), requires_grad=True),
'covariance_matrix': Variable(torch.Tensor([[2.0, 0.3],[0.3, 0.25]]), requires_grad=True),
},
{
'loc': Variable(torch.randn(2), requires_grad=True),
'scale_tril': Variable(torch.Tensor([[2.0, 0.0],[-0.5, 0.25]]), requires_grad=True),
},
{
'loc': torch.Tensor([1.0, -1.0]),
'covariance_matrix': torch.Tensor([[5.0, -0.5],[-0.5, 1.5]]),
},
]),
Example(Poisson, [
{
'rate': Variable(torch.randn(5, 5).abs(), requires_grad=True),
@@ -784,6 +798,99 @@ def test_normal_sample(self):
scipy.stats.norm(loc=loc, scale=scale),
'Normal(mean={}, std={})'.format(loc, scale))

def test_multivariate_normal_shape(self):
mean = Variable(torch.randn(5, 3), requires_grad=True)
mean_no_batch = Variable(torch.randn(3), requires_grad=True)
mean_multi_batch = Variable(torch.randn(6, 5, 3), requires_grad=True)

# construct PSD covariance
tmp = torch.randn(3, 10)
cov = Variable(torch.matmul(tmp, tmp.t())/tmp.shape[-1], requires_grad=True)
scale_tril = Variable(torch.potrf(cov.data, upper=False), requires_grad=True)

# construct batch of PSD covariances
tmp = torch.randn(6, 5, 3, 10)
cov_batched = Variable((tmp.unsqueeze(-2)*tmp.unsqueeze(-3)).mean(-1), requires_grad=True)
scale_tril_batched = Variable(torch.stack([torch.potrf(C, upper=False) for C in cov_batched.data.view((-1,3,3))]).view(cov_batched.shape), requires_grad=True)

# ensure that sample, batch, event shapes all handled correctly
self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3))
self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3))
self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean, cov).sample((2,7)).size(), (2, 7, 5, 3))
self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,7)).size(), (2, 7, 3))
self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean, cov_batched).sample((2,7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean_no_batch, cov_batched).sample((2,7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean_multi_batch, cov_batched).sample((2,7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril).sample((2,7)).size(), (2, 7, 5, 3))
self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril_batched).sample((2,7)).size(), (2, 7, 6, 5, 3))

# check gradients
self._gradcheck_log_prob(MultivariateNormal, (mean, cov))
self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov))
self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov_batched))
self._gradcheck_log_prob(MultivariateNormal, (mean, None, scale_tril))
self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, scale_tril_batched))

# check these also work for tensors, not just variables
mean = mean.data
mean_no_batch = mean_no_batch.data
mean_multi_batch = mean_multi_batch.data
cov = cov.data
cov_batched = cov_batched.data
self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3))
self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3))
self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean, cov).sample((2,7)).size(), (2, 7, 5, 3))
self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,7)).size(), (2, 7, 3))
self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean, cov_batched).sample((2,7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean_no_batch, cov_batched).sample((2,7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean_multi_batch, cov_batched).sample((2,7)).size(), (2, 7, 6, 5, 3))


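(Sketch, not part of the test suite: a compact restatement of the shape semantics exercised above, reusing the imports at the top of this file. sample(sample_shape) returns sample_shape + batch_shape + event_shape, with loc and covariance_matrix broadcast against each other over the batch dimensions.)

loc = Variable(torch.randn(6, 5, 3))                 # batch_shape (6, 5), event_shape (3,)
cov = Variable(torch.eye(3).expand(6, 5, 3, 3))      # broadcast-compatible batch of covariances
d = MultivariateNormal(loc, cov)
assert d.sample((2, 7)).size() == (2, 7, 6, 5, 3)    # sample_shape + batch_shape + event_shape
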
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_multivariate_normal_log_prob(self):
mean = Variable(torch.randn(3), requires_grad=True)
tmp = torch.randn(3, 10)
cov = Variable(torch.matmul(tmp, tmp.t())/tmp.shape[-1], requires_grad=True)
scale_tril = Variable(torch.potrf(cov.data, upper=False), requires_grad=True)

# check that logprob values match scipy logpdf,
# and that covariance and scale_tril parameters are equivalent
dist1 = MultivariateNormal(mean, cov)
dist2 = MultivariateNormal(mean, scale_tril=scale_tril)
ref_dist = scipy.stats.multivariate_normal(mean.data.numpy(), cov.data.numpy())

x = dist1.sample((10,))
expected = ref_dist.logpdf(x.data.numpy())

self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).data.numpy() - expected)**2), places=3)
self.assertAlmostEqual(0.0, np.mean((dist2.log_prob(x).data.numpy() - expected)**2), places=3)

# Double-check that batched versions behave the same as unbatched
mean = Variable(torch.randn(5, 3), requires_grad=True)
tmp = torch.randn(5, 3, 10)
cov = Variable((tmp.unsqueeze(-2)*tmp.unsqueeze(-3)).mean(-1), requires_grad=True)

dist_batched = MultivariateNormal(mean, cov)
dist_unbatched = [MultivariateNormal(mean[i], cov[i]) for i in range(mean.size(0))]

x = dist_batched.sample((10,))
batched_prob = dist_batched.log_prob(x)
unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:,i]) for i in range(5)]).t()

self.assertEqual(batched_prob.shape, unbatched_prob.shape)
self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)


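(Sketch, not part of the diff: the two parameterizations compared above are related by a Cholesky factorization, cov = scale_tril @ scale_tril.t(), which is also how the tests build a random positive-definite covariance.)

tmp = torch.randn(3, 10)
cov = torch.matmul(tmp, tmp.t()) / tmp.shape[-1]   # random PSD matrix, as in the tests above
scale_tril = torch.potrf(cov, upper=False)         # lower Cholesky factor
assert (torch.matmul(scale_tril, scale_tril.t()) - cov).abs().max() < 1e-4
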
def test_exponential(self):
rate = Variable(torch.randn(5, 5).abs(), requires_grad=True)
rate_1d = Variable(torch.randn(1).abs(), requires_grad=True)
2 changes: 2 additions & 0 deletions torch/distributions/__init__.py
@@ -48,6 +48,7 @@
from .laplace import Laplace
from .log_normal import LogNormal
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
@@ -73,6 +74,7 @@
'Laplace',
'LogNormal',
'Multinomial',
'MultivariateNormal',
'Normal',
'OneHotCategorical',
'Pareto',
32 changes: 30 additions & 2 deletions torch/distributions/constraints.py
@@ -15,8 +15,10 @@
'lower_triangular',
'nonnegative_integer',
'positive',
'positive_definite',
'positive_integer',
'real',
'real_vector',
'simplex',
'unit_interval',
]
@@ -166,21 +168,45 @@ class _LowerTriangular(Constraint):
Constrain to lower-triangular square matrices.
"""
def check(self, value):
return (torch.tril(value) == value).min(-1)[0].min(-1)[0]
masked_value = value*torch.tril(value.new(*value.shape[-2:]).fill_(1.0))
return (masked_value == value).min(-1)[0].min(-1)[0]


class _LowerCholesky(Constraint):
"""
Constrain to lower-triangular square matrices with positive diagonals.
"""
def check(self, value):
masked_value = value*torch.tril(value.new(*value.shape[-2:]).fill_(1.0))
lower_triangular = (masked_value == value).min(-1)[0].min(-1)[0]

n = value.size(-1)
diag_mask = torch.eye(n, n, out=value.new(n, n))
lower_triangular = (torch.tril(value) == value).min(-1)[0].min(-1)[0]
positive_diagonal = (value * diag_mask > (diag_mask - 1)).min(-1)[0].min(-1)[0]
return lower_triangular & positive_diagonal


class _PositiveDefinite(Constraint):
"""
Constrain to positive-definite matrices.
"""
def check(self, value):
matrix_shape = value.shape[-2:]
batch_shape = value.unsqueeze(0).shape[:-2]
# TODO: replace with batched linear algebra routine when one becomes available
# note that `symeig()` returns eigenvalues in ascending order
return torch.stack([v.symeig()[0][:1] > 0.0 for v in value.contiguous().view((-1,)+matrix_shape)]).view(batch_shape)


class _RealVector(Constraint):
"""
Constrain to real-valued vectors. This is the same as `constraints.real`,
but additionally reduces across the `event_shape` dimension.
"""
def check(self, value):
return (value == value).min(-1)[0]


# Public interface.
dependent = _Dependent()
dependent_property = _DependentProperty
@@ -189,6 +215,7 @@ def check(self, value):
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
real = _Real()
real_vector = _RealVector()
positive = _GreaterThan(0)
greater_than = _GreaterThan
less_than = _LessThan
@@ -197,3 +224,4 @@ def check(self, value):
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
positive_definite = _PositiveDefinite()
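
(Illustrative sketch, not part of the diff: exercising the new constraints. Each check() reduces over the trailing matrix/vector dimensions and returns a byte tensor with roughly one entry per batch element.)

import torch
from torch.distributions import constraints

A = torch.Tensor([[2.0, 0.3], [0.3, 0.25]])               # a positive-definite matrix
L = torch.potrf(A, upper=False)                            # its lower Cholesky factor
print(constraints.positive_definite.check(A))              # 1: A is positive definite
print(constraints.lower_cholesky.check(L))                 # 1: lower triangular with positive diagonal
print(constraints.real_vector.check(torch.randn(5, 3)))    # one entry per batch element, shape (5,)
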
154 changes: 154 additions & 0 deletions torch/distributions/multivariate_normal.py
@@ -0,0 +1,154 @@
import math
from numbers import Number

import torch
from torch.autograd import Variable
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all, lazy_property


def _get_batch_shape(bmat, bvec):
"""
Given a batch of matrices and a batch of vectors, compute the combined `batch_shape`.
"""
try:
vec_shape = torch._C._infer_size(bvec.shape, bmat.shape[:-1])
except RuntimeError:
raise ValueError("Incompatible batch shapes: vector {}, matrix {}".format(bvec.shape, bmat.shape))
return torch.Size(vec_shape[:-1])


def _batch_mv(bmat, bvec):
"""
Performs a batched matrix-vector product, with an arbitrary batch shape.
"""
batch_shape = bvec.shape[:-1]
event_dim = bvec.shape[-1]
bmat = bmat.expand(batch_shape + (event_dim, event_dim))
if batch_shape != bmat.shape[:-2]:
raise ValueError("Batch shapes do not match: matrix {}, vector {}".format(bmat.shape, bvec.shape))
bvec = bvec.unsqueeze(-1)

    # conform with the `torch.bmm` interface, which expects inputs with `.dim() == 3`
    if bvec.dim() == 2:
        bvec = bvec.unsqueeze(0)
# flatten batch dimensions
bvec = bvec.contiguous().view((-1, event_dim, 1))
bmat = bmat.contiguous().view((-1, event_dim, event_dim)).expand((bvec.shape[0], -1, -1))
return torch.bmm(bmat, bvec).squeeze(-1).view(batch_shape+(event_dim,))


def _batch_potrf_lower(bmat):
"""
    Computes the lower Cholesky factor of each matrix in a batch of arbitrary batch shape.
"""
n = bmat.size(-1)
cholesky = torch.stack([C.potrf(upper=False) for C in bmat.unsqueeze(0).contiguous().view((-1,n,n))])
return cholesky.view(bmat.shape)


def _batch_diag(bmat):
"""
Returns the diagonals of a batch of square matrices.
"""
n = bmat.size(-1)
dims = torch.arange(n, out=bmat.new(n)).long()
if isinstance(dims, Variable):
dims = dims.data # TODO: why can't I index with a Variable?
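    # Indexing with the same index vector in the last two dimensions picks out
    # bmat[..., i, i], giving a tensor of shape bmat.shape[:-1] holding the diagonals.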
return bmat[...,dims,dims]

Review comment: Nice!



def _batch_mahalanobis(L, x):
r"""
Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

Accepts batches for both L and x.
"""
# TODO: use `torch.potrs` or similar once a backwards pass is implemented.
flat_L = L.unsqueeze(0).contiguous().view((-1,)+L.shape[-2:])
L_inv = torch.stack([torch.inverse(Li.t()) for Li in flat_L]).view(L.shape)
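    # Broadcasting x against L_inv = L^{-T} and summing over dim -2 computes L^{-1} x,
    # so the value returned below is ||L^{-1} x||^2 = x^T (L L^T)^{-1} x = x^T M^{-1} x.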
return (x.unsqueeze(-1) * L_inv).sum(-2).pow(2.0).sum(-1)


class MultivariateNormal(Distribution):
r"""
Creates a multivariate normal (also called Gaussian) distribution
parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or in terms of a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix can be
    obtained via e.g. a Cholesky decomposition of the covariance.

Example:

>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
-0.2102
-0.5429
[torch.FloatTensor of size 2]

Args:
loc (Tensor or Variable): mean of the distribution
covariance_matrix (Tensor or Variable): positive-definite covariance matrix
scale_tril (Tensor or Variable): lower-triangular factor of covariance

Note:
        Exactly one of `covariance_matrix` or `scale_tril` must be specified.

"""
params = {'loc': constraints.real_vector,
'covariance_matrix': constraints.positive_definite,
'scale_tril': constraints.lower_cholesky }
support = constraints.real
has_rsample = True

def __init__(self, loc, covariance_matrix=None, scale_tril=None):
event_shape = torch.Size(loc.shape[-1:])
if (covariance_matrix is None) == (scale_tril is None):
raise ValueError("Exactly one of covariance_matrix or scale_tril may be specified (but not both).")
if scale_tril is None:
if covariance_matrix.dim() < 2:
raise ValueError("covariance_matrix must be two-dimensional")
self.covariance_matrix = covariance_matrix
batch_shape = _get_batch_shape(covariance_matrix, loc)
else:
if scale_tril.dim() < 2:
raise ValueError("scale_tril matrix must be two-dimensional")
self.scale_tril = scale_tril
batch_shape = _get_batch_shape(scale_tril, loc)
self.loc = loc
super(MultivariateNormal, self).__init__(batch_shape, event_shape)

@lazy_property
def scale_tril(self):
return _batch_potrf_lower(self.covariance_matrix)

@lazy_property
def covariance_matrix(self):
# To use torch.bmm, we first squash the batch_shape into a single dimension
flat_scale_tril = self.scale_tril.unsqueeze(0).contiguous().view((-1,)+self._event_shape*2)
return torch.bmm(flat_scale_tril, flat_scale_tril.transpose(-1,-2)).view(self.scale_tril.shape)

def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
eps = self.loc.new(*shape).normal_()
return self.loc + _batch_mv(self.scale_tril, eps)

def log_prob(self, value):
self._validate_log_prob_arg(value)
delta = value - self.loc
M = _batch_mahalanobis(self.scale_tril, delta)
log_det = _batch_diag(self.scale_tril).abs().log().sum(-1)
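        # log_det = sum_i log|L_ii| = 0.5 * log|Sigma|, so the value returned below is
        # -0.5 * [(x - mu)^T Sigma^{-1} (x - mu) + k*log(2*pi) + log|Sigma|], with k the event dimension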
return -0.5*(M + self.loc.size(-1)*math.log(2*math.pi)) - log_det

def entropy(self):
log_det = _batch_diag(self.scale_tril).abs().log().sum(-1)
Review comment: Hmm, shouldn't this already have the correct shape? Why is the H.expand(self._batch_shape) below needed?

H = 0.5*(1.0 + math.log(2*math.pi))*self._event_shape[0] + log_det
if len(self._batch_shape) == 0:
return H
else:
return H.expand(self._batch_shape)
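
(Illustrative sketch, not part of the diff: end-to-end use of the new class for maximum-likelihood style gradients; the observed data values here are made up for the example. Since has_rsample = True, rsample() also supports reparameterized, differentiable sampling.)

import torch
from torch.autograd import Variable
from torch.distributions import MultivariateNormal

loc = Variable(torch.zeros(2), requires_grad=True)
scale_tril = Variable(torch.Tensor([[2.0, 0.0], [-0.5, 0.25]]), requires_grad=True)
m = MultivariateNormal(loc, scale_tril=scale_tril)

data = Variable(torch.Tensor([[0.5, -0.3], [1.0, 2.0]]))  # hypothetical observations
nll = -m.log_prob(data).sum()                              # negative log-likelihood of the data
nll.backward()                                             # gradients land in loc.grad and scale_tril.grad

x = m.rsample((10,))                                       # (10, 2) reparameterized samples
print(m.entropy())                                         # analytic entropy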