WIP implementation of multivariate normal distribution #52
```diff
@@ -33,7 +33,7 @@
 from torch.distributions import Distribution
 from torch.distributions import (Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2,
                                  Dirichlet, Exponential, FisherSnedecor, Gamma, Geometric,
-                                 Gumbel, Laplace, Normal, OneHotCategorical, Multinomial, Pareto,
+                                 Gumbel, Laplace, MultivariateNormal, Normal, OneHotCategorical, Multinomial, Pareto,
                                  StudentT, Uniform, kl_divergence)
 from torch.distributions.dirichlet import _Dirichlet_backward
 from torch.distributions.constraints import Constraint, is_dependent
```
```diff
@@ -163,6 +163,20 @@ def pairwise(Dist, *params):
         'scale': torch.Tensor([1e-5, 1e-5]),
     },
 ]),
+Example(MultivariateNormal, [
+    {
+        'loc': Variable(torch.randn(5, 2), requires_grad=True),
+        'covariance_matrix': Variable(torch.Tensor([[2.0, 0.3], [0.3, 0.25]]), requires_grad=True),
+    },
+    {
+        'loc': Variable(torch.randn(2), requires_grad=True),
+        'scale_tril': Variable(torch.Tensor([[2.0, 0.0], [-0.5, 0.25]]), requires_grad=True),
+    },
+    {
+        'loc': torch.Tensor([1.0, -1.0]),
+        'covariance_matrix': torch.Tensor([[5.0, -0.5], [-0.5, 1.5]]),
+    },
+]),
 Example(Normal, [
     {
         'loc': Variable(torch.randn(5, 5), requires_grad=True),
```
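The two parameterizations exercised above are meant to describe the same distribution. A minimal sketch of that equivalence, assuming the `MultivariateNormal` API added in this PR and the era's `torch.potrf` Cholesky routine (the same one the PR uses internally):

```python
import torch
from torch.distributions import MultivariateNormal

# A positive-definite covariance and its lower-triangular Cholesky factor.
cov = torch.Tensor([[2.0, 0.3], [0.3, 0.25]])
scale_tril = torch.potrf(cov, upper=False)  # cov == torch.mm(scale_tril, scale_tril.t())

loc = torch.zeros(2)
d_cov = MultivariateNormal(loc, covariance_matrix=cov)
d_tril = MultivariateNormal(loc, scale_tril=scale_tril)

# Both parameterizations should assign the same density to any point,
# up to floating-point error.
x = d_cov.sample()
print(d_cov.log_prob(x), d_tril.log_prob(x))
```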
```diff
@@ -636,6 +650,66 @@ def test_normal_sample(self):
                              scipy.stats.norm(loc=loc, scale=scale),
                              'Normal(mean={}, std={})'.format(loc, scale))

+    def test_multivariate_normal_shape(self):
+        mean = Variable(torch.randn(5, 3), requires_grad=True)
+        mean_no_batch = Variable(torch.randn(3), requires_grad=True)
+        mean_multi_batch = Variable(torch.randn(6, 5, 3), requires_grad=True)
+        tmp = torch.randn(3, 10)
+        cov = Variable(torch.matmul(tmp, tmp.t()) / tmp.shape[-1], requires_grad=True)
+        scale_tril = Variable(torch.potrf(cov.data, upper=False), requires_grad=True)
+
+        # ensure that sample, batch, event shapes all handled correctly
+        self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
+        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
+        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3))
+        self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
+        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3))
+        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3))
+        self.assertEqual(MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3))
+        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3))
+        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3))
+        self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(), (2, 7, 5, 3))
+
+        # check gradients
```
**Review comment:** nice tests!
```diff
+        self._gradcheck_log_prob(MultivariateNormal, (mean, cov))
+        self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov))
+        self._gradcheck_log_prob(MultivariateNormal, (mean, None, scale_tril))
+
+        # check these also work for tensors, not just variables
+        mean = mean.data
+        mean_no_batch = mean_no_batch.data
+        mean_multi_batch = mean_multi_batch.data
+        cov = cov.data
+        self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
+        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
+        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3))
+        self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
+        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3))
+        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3))
+        self.assertEqual(MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3))
+        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3))
+        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3))
```
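All of these assertions follow the standard `torch.distributions` shape rule: a draw has shape `sample_shape + batch_shape + event_shape`, where, per the `__init__` below, `batch_shape` is `loc.shape[:-1]` and `event_shape` is `loc.shape[-1:]`. For example, `sample((2, 7))` on a distribution with `batch_shape=(6, 5)` and `event_shape=(3,)` yields `(2, 7, 6, 5, 3)`, matching the last assertion above.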
```diff
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_multivariate_normal_log_prob(self):
+        mean = Variable(torch.randn(3), requires_grad=True)
+        tmp = torch.randn(3, 10)
+        cov = Variable(torch.matmul(tmp, tmp.t()) / tmp.shape[-1], requires_grad=True)
+        scale_tril = Variable(torch.potrf(cov.data, upper=False), requires_grad=True)
+
+        # check that logprob values match scipy logpdf,
+        # and that covariance and scale_tril parameters are equivalent
+        dist1 = MultivariateNormal(mean, cov)
+        dist2 = MultivariateNormal(mean, scale_tril=scale_tril)
+        ref_dist = scipy.stats.multivariate_normal(mean.data.numpy(), cov.data.numpy())
+
+        x = dist1.sample((10,))
+        expected = ref_dist.logpdf(x.data.numpy())
+
+        self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).data.numpy() - expected) ** 2), places=3)
+        self.assertAlmostEqual(0.0, np.mean((dist2.log_prob(x).data.numpy() - expected) ** 2), places=3)
```
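For reference, the quantity these assertions cross-check against `scipy.stats.multivariate_normal.logpdf` is the standard multivariate normal log density, with `k` the event dimension:

```latex
\log p(x) = -\tfrac{1}{2}\left[(x-\mu)^\top \Sigma^{-1} (x-\mu) + k\log(2\pi) + \log\det\Sigma\right],
\qquad \log\det\Sigma = 2\sum_{i}\log L_{ii} \quad\text{for}\quad \Sigma = LL^\top.
```

Reading the log-determinant off the diagonal of `L` is why the `scale_tril` parameterization needs no extra determinant computation.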
```diff
+
     def test_exponential(self):
         rate = Variable(torch.randn(5, 5).abs(), requires_grad=True)
         rate_1d = Variable(torch.randn(1).abs(), requires_grad=True)
```
New file (`@@ -0,0 +1,90 @@`):

```python
import math
from numbers import Number

import torch
from torch.autograd import Variable
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all


class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a lower-triangular matrix :math:`\mathbf{L}` such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`, as obtained e.g. via
    Cholesky decomposition of the covariance.

    Example:

        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        -0.2102
        -0.5429
        [torch.FloatTensor of size 2]

    Args:
        loc (Tensor or Variable): mean of the distribution
        covariance_matrix (Tensor or Variable): positive-definite covariance matrix
        scale_tril (Tensor or Variable): lower-triangular factor of the covariance

    Note:
        Only one of `covariance_matrix` or `scale_tril` can be specified.
    """
    params = {'loc': constraints.real,
              'covariance_matrix': constraints.positive_definite,
              'scale_tril': constraints.lower_triangular}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, covariance_matrix=None, scale_tril=None):
        batch_shape, event_shape = loc.shape[:-1], loc.shape[-1:]
        if covariance_matrix is not None and scale_tril is not None:
            raise ValueError("Either covariance_matrix or scale_tril may be specified, not both.")
        if covariance_matrix is None and scale_tril is None:
            raise ValueError("One of covariance_matrix or scale_tril must be specified.")
        if scale_tril is None:
```
**Review comment:** Neeraj made this cool decorator called `lazy_property`:

```python
class MultivariateNormal(Distribution):
    def __init__(...):
        ...
        if scale_tril is not None:
            self.scale_tril = scale_tril
            # leave .covariance_matrix unset
        else:
            self.covariance_matrix = covariance_matrix
            # leave .scale_tril unset
        ...

    @lazy_property
    def scale_tril(self):
        return torch.potrf(self.covariance_matrix, upper=False)

    @lazy_property
    def covariance_matrix(self):
        return torch.mm(self.scale_tril, self.scale_tril.t())
```
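For context, a minimal sketch of what such a `lazy_property` decorator could look like. This is an illustrative assumption, not the reviewer's actual implementation (which is not shown in this diff): it computes the attribute on first access and caches it on the instance.

```python
class lazy_property(object):
    """Compute the wrapped method once per instance, then cache the result."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        value = self.wrapped(instance)
        # Cache on the instance; later lookups find the plain attribute
        # before this (non-data) descriptor is consulted again.
        setattr(instance, self.wrapped.__name__, value)
        return value
```

Because it is a non-data descriptor (no `__set__`), an attribute assigned directly in `__init__` shadows it, which is exactly the "leave unset" behavior in the snippet above.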
```python
            assert covariance_matrix.dim() >= 2
            if covariance_matrix.dim() > 2:
                # TODO support batch_shape for covariance
                raise NotImplementedError("batch_shape for covariance matrix is not yet supported")
            else:
                scale_tril = torch.potrf(covariance_matrix, upper=False)
        else:
            assert scale_tril.dim() >= 2
            if scale_tril.dim() > 2:
                # TODO support batch_shape for scale_tril
                raise NotImplementedError("batch_shape for scale_tril is not yet supported")
            else:
                covariance_matrix = torch.mm(scale_tril, scale_tril.t())
        self.loc = loc
        self.covariance_matrix = covariance_matrix
        self.scale_tril = scale_tril
        super(MultivariateNormal, self).__init__(batch_shape, event_shape)

    def rsample(self, sample_shape=torch.Size()):
        # Reparameterized sample: x = loc + eps L^T with eps ~ N(0, I),
        # so gradients flow through loc and scale_tril.
        shape = self._extended_shape(sample_shape)
        eps = self.loc.new(*shape).normal_()
        return self.loc + torch.matmul(eps, self.scale_tril.t())

    def log_prob(self, value):
        self._validate_log_prob_arg(value)
        delta = value - self.loc
        # Mahalanobis term delta^T Sigma^{-1} delta, batched over leading dims.
        # TODO replace torch.gesv with appropriate solver (e.g. potrs)
        M = (delta * torch.gesv(delta.view(-1, delta.shape[-1]).t(),
                                self.covariance_matrix)[0].t().view(delta.shape)).sum(-1)
        # Sum of the log-diagonal of the Cholesky factor is (1/2) log det Sigma.
        log_det = torch.log(self.scale_tril.diag()).sum()
        return -0.5 * (M + self.loc.size(-1) * math.log(2 * math.pi)) - log_det

    def entropy(self):
        # TODO this will need to be modified to support batched covariance
        log_det = self.scale_tril.diag().log().sum(-1, keepdim=True)
        # log_det is already (1/2) log det Sigma, so it enters without the
        # k/2 factor that multiplies the constant term.
        H = 0.5 * self.loc.shape[-1] * (1.0 + math.log(2 * math.pi)) + log_det
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
```
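The `entropy` method implements the closed-form differential entropy of a `k`-dimensional Gaussian, with the log-determinant read off the Cholesky diagonal:

```latex
H = \tfrac{k}{2}\bigl(1 + \log(2\pi)\bigr) + \tfrac{1}{2}\log\det\Sigma
  = \tfrac{k}{2}\bigl(1 + \log(2\pi)\bigr) + \sum_{i}\log L_{ii}.
```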
**Review comment:** nit: underline is too short and will break docs. You can run `make -C docs html` and open `docs/build/html/index.html` to check docs.