Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/pyro.distributions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Normal

MultivariateNormal
------------------
.. automodule:: pyro.distributions.multivariate_normal
.. automodule:: pyro.distributions.torch.multivariate_normal
:members:
:undoc-members:
:show-inheritance:
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from pyro.distributions.distribution import Distribution # noqa: F401
from pyro.distributions.half_cauchy import HalfCauchy
from pyro.distributions.log_normal import LogNormal
from pyro.distributions.multivariate_normal import MultivariateNormal
from pyro.distributions.poisson import Poisson
from pyro.distributions.random_primitive import RandomPrimitive

Expand All @@ -33,6 +32,7 @@
from pyro.distributions.torch.exponential import Exponential
from pyro.distributions.torch.gamma import Gamma
from pyro.distributions.torch.multinomial import Multinomial
from pyro.distributions.torch.multivariate_normal import MultivariateNormal
from pyro.distributions.torch.normal import Normal
from pyro.distributions.torch.one_hot_categorical import OneHotCategorical
from pyro.distributions.torch.uniform import Uniform
Expand Down
170 changes: 0 additions & 170 deletions pyro/distributions/multivariate_normal.py

This file was deleted.

81 changes: 81 additions & 0 deletions pyro/distributions/torch/multivariate_normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import absolute_import, division, print_function

import math

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property

from pyro.distributions.torch_wrapper import TorchDistribution
from pyro.distributions.util import broadcast_shape, copy_docs_from


def _matrix_inverse_compat(matrix, matrix_chol):
"""Computes the inverse of a positive semidefinite square matrix."""
if matrix.requires_grad:
# If derivatives are required, use the more expensive inverse.
return torch.inverse(matrix)
else:
# Use the cheaper Cholesky based potri.
return torch.potri(matrix_chol)


# TODO Move this upstream to PyTorch.
class TorchMultivariateNormal(torch.distributions.Distribution):
params = {"loc": constraints.real, "scale_tril": constraints.lower_triangular}
support = constraints.real
has_rsample = True

def __init__(self, loc, covariance_matrix, normalized=True):
self.loc = loc
self.covariance_matrix = covariance_matrix
batch_shape, event_shape = loc.shape[:-1], loc.shape[-1:]
super(TorchMultivariateNormal, self).__init__(batch_shape, event_shape)

@lazy_property
def scale_triu(self):
return torch.potrf(self.covariance_matrix)

def rsample(self, sample_shape=torch.Size()):
white = self.loc.new(sample_shape + self.loc.shape).normal_()
return self.loc + torch.matmul(white, self.scale_triu)

def log_prob(self, value):
delta = value - self.loc
sigma_inverse = _matrix_inverse_compat(self.covariance_matrix, self.scale_triu)
normalization_const = ((0.5 * self.event_shape[-1]) * math.log(2 * math.pi) +
self.scale_triu.diag().log().sum(-1))
mahalanobis_squared = (delta * torch.matmul(delta, sigma_inverse)).sum(-1)
return -(normalization_const + 0.5 * mahalanobis_squared)


@copy_docs_from(TorchDistribution)
class MultivariateNormal(TorchDistribution):
"""Multivariate normal (Gaussian) distribution.

A distribution over vectors in which all the elements have a joint Gaussian
density.

:param torch.autograd.Variable loc: Mean.
:param torch.autograd.Variable covariance_matrix: Covariance matrix.
Must be symmetric and positive semidefinite.
"""
reparameterized = True

def __init__(self, loc, covariance_matrix, *args, **kwargs):
torch_dist = TorchMultivariateNormal(loc, covariance_matrix)
x_shape = torch.Size(broadcast_shape(loc.shape, covariance_matrix.shape[:-1], strict=True))
event_dim = 1
super(MultivariateNormal, self).__init__(torch_dist, x_shape, event_dim, *args, **kwargs)

def batch_log_pdf(self, x):
batch_log_pdf = self.torch_dist.log_prob(x).view(self.batch_shape(x) + (1,))
if self.log_pdf_mask is not None:
batch_log_pdf = batch_log_pdf * self.log_pdf_mask
return batch_log_pdf

def analytic_mean(self):
return self.torch_dist.loc

def analytic_var(self):
return torch.diag(self.torch_dist.covariance_matrix)
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
[flake8]
max-line-length = 120
exclude = docs/src, examples/storyboard
exclude = docs/src

[isort]
line_length = 120
not_skip = __init__.py
known_first_party = pyro, tests
known_third_party = six, torch

[tool:pytest]
filterwarnings = error
Expand Down
6 changes: 2 additions & 4 deletions tests/distributions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,10 @@
scipy_dist=sp.multivariate_normal,
examples=[
{'loc': [2.0, 1.0], 'covariance_matrix': [[1.0, 0.5], [0.5, 1.0]],
'test_data': [[2.0, 1.0], [9.0, 3.4]]},
{'loc': [2.0, 1.0], 'scale_tril': [[1.0, 0.5], [0, 3900231685776981/4503599627370496]],
'test_data': [[2.0, 1.0], [9.0, 3.4]]}
'test_data': [[2.0, 1.0], [9.0, 3.4]]},
],
# This hack seems to be the best option right now, as 'sigma' is not handled well by get_scipy_batch_logpdf
scipy_arg_fn=lambda loc, covariance_matrix=None, scale_tril=None:
scipy_arg_fn=lambda loc, covariance_matrix=None, scale_triu=None:
((), {"mean": np.array(loc), "cov": np.array([[1.0, 0.5], [0.5, 1.0]])}),
prec=0.01,
min_samples=500000),
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_score_errors_non_broadcastable_data_shape(dist):
shape = d.shape(**dist_params)
non_broadcastable_shape = (shape[0] + 1,) + shape[1:]
test_data_non_broadcastable = ng_ones(non_broadcastable_shape)
with pytest.raises(ValueError):
with pytest.raises((ValueError, RuntimeError)):
d.batch_log_pdf(test_data_non_broadcastable, **dist_params)


Expand Down
43 changes: 0 additions & 43 deletions tests/distributions/test_multivariate_normal.py

This file was deleted.