tbrx commented Dec 28, 2017

For issue #1. @fritzo

  • Doesn't support batched covariance matrices at all
  • Does support arbitrary batch sizes for the mean
  • Takes a mean argument, plus (either) cov or scale_tril
  • Uses torch.gesv for computing log_prob; if requires_grad=False then we could do a (cheaper) torch.potrs… probably worth using a solver-helper here like @dwd31415 has in the Pyro PR.

Argument naming convention at the moment is: mean and cov to match scipy.stats.multivariate_normal, and scale_tril to match the Pyro PR.

Test coverage is spotty at the moment (in particular I had some issue with the _gradcheck_log_prob helper), but shapes seem okay and logprob values match scipy.

One question is when we should compute the Cholesky decomposition if passed a cov argument instead of scale_tril. I opted to compute it up front in the constructor -- we're ultimately going to need it no matter what, either for sampling or for computing the log determinant in log_prob or entropy.
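For context, a minimal sketch (not the PR's code; assumes an unbatched positive-definite cov and a mean vector mean) of why the factor is needed either way:

> scale_tril = torch.potrf(cov, upper=False)    # lower-triangular L with cov = L L^T
> # sampling: x = mean + L eps, with eps ~ N(0, I)
> eps = mean.new(mean.size()).normal_()
> sample = mean + torch.matmul(scale_tril, eps)
> # log-determinant for log_prob / entropy: log det(cov) = 2 * sum(log(diag(L)))
> log_det_cov = 2 * scale_tril.diag().log().sum()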

@fritzo left a comment

Nice! Could you also add a section to docs/source/distributions.rst and maybe cd docs; make html to ensure docs still build (I've caught my own typos this way).

You can also take a look at the recent OneHotCategorical tests, since that is also a "multivariate" distribution with nontrivial event_shape.

 from torch.autograd import Variable, gradcheck
 from torch.distributions import (Bernoulli, Beta, Categorical, Dirichlet,
-                                 Exponential, Gamma, Laplace, Normal)
+                                 Exponential, Gamma, Laplace, Normal, MultivariateNormal)

Could you also add some example parameters in EXAMPLES below, ideally one that specifies cov and another that specifies scale_tril?
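Something along these lines might work (a hedged sketch; the exact structure of the Example entries in test_distributions.py may differ, and the parameter values here are made up):

Example(MultivariateNormal, [
    {
        'mean': Variable(torch.randn(5, 2), requires_grad=True),
        'cov': Variable(torch.Tensor([[2.0, 0.3], [0.3, 0.25]]), requires_grad=True),
    },
    {
        'mean': Variable(torch.randn(2), requires_grad=True),
        'scale_tril': Variable(torch.Tensor([[2.0, 0.0], [-0.5, 0.25]]), requires_grad=True),
    },
]),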

def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    eps = self.mean.new(*shape).normal_()
    return self.mean + torch.matmul(eps, self.scale_tril.t())

Is it possible to support batched .rsample() by using torch.bmm() here instead of torch.matmul()? I'm not sure what's blocking batched covariance.
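For what it's worth, a rough, hypothetical sketch of what the batched product could look like with torch.bmm, assuming scale_tril has shape (batch, n, n) and the mean has shape (batch, n):

> eps = mean.new(mean.size()).normal_()                                # (batch, n)
> sample = mean + torch.bmm(scale_tril, eps.unsqueeze(-1)).squeeze(-1)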

Collaborator Author

That might work for distributions specified via scale_tril as opposed to covariance_matrix. The primary blocker is a batched torch.potrf. We also need a batched solver (batched torch.gesv or otherwise) to compute the log probability.

I think we can maybe do this by using torch.btrifact and torch.btrisolve instead of potrf and gesv. Haven't looked into it yet.

Collaborator Author

Unfortunately, I didn't realize that torch.btrifact doesn't actually support .backward() calls.

Args:
    mean (Tensor or Variable): mean of the distribution
    cov (Tensor or Variable): covariance matrix (sigma, positive-definite).

nit: It would be nice to stay maximally compatible with Tensorflow.distributions and name this covariance_matrix.

Collaborator Author

Yeah, I just left this matching the scipy mean and cov because they were so much shorter. Happy to change to loc and covariance_matrix if that is what we've settled on.


Yeah I'd like to keep the interfaces similar if possible, but I defer to your judgement here.


fritzo commented Jan 18, 2018

@tbrx How's this going? I might have time this weekend to try adding some bits of our Pyro implementation into this branch. I think it's fine to provide a batching-complete interface even if we might need to do some python iteration under the hood for now, until pytorch#4612 merges.


tbrx commented Jan 19, 2018

I haven't touched it since the last push — was waiting to see if a batched torch.gesv or batched torch.potrf would magically make its way into upstream master :)

At the moment this actually works fine, with the caveat that a batch shape on the covariance_matrix or scale_tril parameters will throw NotImplementedError.

Adding a version of this which handles batches by using python loops or list comprehensions shouldn't be too difficult…
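For instance, a batched Cholesky could be faked along these lines (an illustrative sketch only, using a Python list comprehension around torch.potrf; the helper name is hypothetical):

def _batched_potrf_lower(bmat):
    # factor each matrix in the batch individually until a batched torch.potrf exists
    flat = bmat.contiguous().view(-1, bmat.size(-2), bmat.size(-1))
    chols = [torch.potrf(m, upper=False) for m in flat]
    return torch.stack(chols).view(bmat.size())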


fritzo commented Jan 19, 2018

Oh great, if it already works could we merge it now, then add full batched support in a follow-up PR? It would be nice to help motivate batched linear algebra work in PyTorch by claiming that "if xxx operation were batched then torch.distributions.MultivariateNormal would get batched covariance support for free".


tbrx commented Jan 20, 2018

Okay — I think for that, all we need to do is update / expand the tests. If there are any other updates that happened to the Pyro version it would be nice to merge them in too.

Maybe it is worth implementing a slow version with batched covariance matrices first just to get the API correct, though. If the batch size is reasonably small it shouldn't be too slow.


fritzo commented Jan 20, 2018

@tbrx It would make @neerajprad's and my job easier if you could merge this PR soon, simply adding tests and pushing further enhancements to follow-up PRs.

The Pyro team has already migrated to PyTorch distributions, and we're working around the lack of MultivariateNormal in PyTorch master by building our own PyTorch-style TorchMultivariateNormal and wrapping it as we do with all other PyTorch distributions (pyro-ppl/pyro#693). Until there is a MultivariateNormal in PyTorch master, any enhancements from our side will land in our fork of MultivariateNormal.


tbrx commented Jan 20, 2018

That makes sense — the lack of a batch dimension on the covariance_matrix doesn't cause issues for you in Pyro if I understand correctly?

I can update this PR and add the remaining tests Monday morning my time.


tbrx commented Jan 22, 2018

What sort of constraint should we use here for the loc vector and the support? The existing constraints.real trivially works, but possibly we would want a constraint which also takes into account the event_shape for vector- or matrix- valued distributions.


fritzo commented Jan 22, 2018

Yeah, I've been thinking about that. I think we should introduce new constraints:

  • constraints.real_vector
  • constraints.positive_definite
  • constraints.cholesky_triu or constraints.cholesky_tril. We can get rid of constraints.lower_triangular since that fails to capture positivity of the diagonal, and therefore isn't really useful for anything (my mistake!).

Does that seem reasonable? They're simply symbolic placeholders, but we'll use them to register Transforms between constrained and unconstrained spaces.
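For example, a rough, hypothetical sketch of what the vector-valued placeholder could look like (the eventual upstream implementation may differ):

from torch.distributions import constraints

class _RealVector(constraints.Constraint):
    # like constraints.real, but reduces over the event dimension, so check()
    # returns one result per batch element rather than per coordinate
    def check(self, value):
        return (value == value).all(dim=-1)  # NaN != NaN, so this rejects NaNs

real_vector = _RealVector()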


fritzo commented Jan 22, 2018

BTW I've added issue #99 for implementing a BivariateNormal distribution as a reference implementation that does all linear algebra by hand (e.g. no torch.trtrs). We did this in Pyro and found some silent bugs in our MultivariateNormal gradients due to non-differentiability of some of PyTorch's linear algebra operations. I'd like to have some of these Bivariate-matches-Multivariate tests in PyTorch before release, to ensure gradients are not being silently corrupted.
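For concreteness, a by-hand 2-D log-density along these lines (a hedged sketch, not the issue #99 code; assumes mu, sigma1, sigma2, rho are tensors/Variables) needs only elementary, fully differentiable ops:

import math
import torch

def bivariate_normal_log_prob(x, mu, sigma1, sigma2, rho):
    # hand-written 2-D Gaussian log-density, usable as a gradient reference
    z1 = (x[..., 0] - mu[..., 0]) / sigma1
    z2 = (x[..., 1] - mu[..., 1]) / sigma2
    one_minus_rho2 = 1 - rho ** 2
    quad = (z1 ** 2 - 2 * rho * z1 * z2 + z2 ** 2) / one_minus_rho2
    log_norm = torch.log(sigma1) + torch.log(sigma2) + 0.5 * torch.log(one_minus_rho2) + math.log(2 * math.pi)
    return -0.5 * quad - log_norm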

@fritzo left a comment

Looks good so far!

self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril).sample((2,7)).size(), (2, 7, 5, 3))

# check gradients

nice tests!

Only one of `covariance_matrix` or `scale_tril` can be specified.
"""
params = {'loc': constraints.real,

@jwvdm noted that we could generically retrieve params for a distribution if we specified a canonical set of parameters. I've been trying to do this by putting only a single canonical parameterization in the .params dict (e.g. either loc, covariance_matrix or loc, scale_tril, but not all three).

But I like what you've done here by adding them all. Maybe we should do that for all distributions and specify canonical_params or something in another field, or just let higher level libraries like Pyro or ProbTorch do that. WDYT?
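(For reference, a hypothetical sketch of the "list them all" option; the PR currently maps loc to constraints.real, and these constraint names are ones floated elsewhere in this conversation:)

params = {'loc': constraints.real_vector,
          'covariance_matrix': constraints.positive_definite,
          'scale_tril': constraints.lower_cholesky}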


def __init__(self, loc, covariance_matrix=None, scale_tril=None):
    batch_shape, event_shape = loc.shape[:-1], loc.shape[-1:]
    if covariance_matrix is not None and scale_tril is not None:

nit: You could simplify via

if (covariance_matrix is None) == (scale_tril is None):
    raise ValueError(...)

raise ValueError("Either covariance matrix or scale_tril may be specified, not both.")
if covariance_matrix is None and scale_tril is None:
raise ValueError("One of either covariance matrix or scale_tril must be specified")
if scale_tril is None:

Neeraj made this cool decorator called @lazy_property that could allow you to create scale_tril only if it does not exist, which would avoid unnecessary work in some cases. You could use it as follows:

class MultivariateNormal(Distribution):
    def __init__(...):
        ...
        if scale_tril is not None:
            self.scale_tril = scale_tril
            # leave .covariance_matrix unset
        else:
            self.covariance_matrix = covariance_matrix
            # leave .scale_tril unset
    ...
    @lazy_property
    def scale_tril(self):
        return torch.potrf(self.covariance_matrix, upper=False)
    @lazy_property
    def covariance_matrix(self):
        return torch.mm(self.scale_tril, self.scale_tril.t())
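For reference, roughly how such a decorator can work (a sketch, not necessarily the actual @lazy_property code): it is a non-data descriptor, so the first access computes the value and caches it as an instance attribute, and later accesses bypass the descriptor entirely.

class lazy_property(object):
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        value = self.wrapped(instance)                    # compute once
        setattr(instance, self.wrapped.__name__, value)   # cache on the instance
        return value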


fritzo commented Jan 24, 2018

@tbrx In pytorch#4771 I've replaced constraints.lower_triangular with the more useful constraints.lower_cholesky that additionally enforces nonnegativity along the diagonal. I've also implemented a LowerCholeskyTransform for optimizing parameters in this space. Let me know if I've misunderstood anything.


tbrx commented Jan 24, 2018

@fritzo Actually… I like leaving it just as constraints.lower_triangular. While it is true that a scale_tril matrix generated by a Cholesky decomposition of a positive definite matrix would have positive entries on the diagonal, it's not required when specifying scale_tril. Any lower triangular matrix with nonzero entries along the diagonal should be enough for us.

The main (potential) problem would be in computing the determinant of the covariance matrix. But that's actually fine. Here's an example lower-triangular matrix with a negative entry on the diagonal:

> L = torch.Tensor([[ 1.0,  0.0, 0.0],
                    [-2.0, -1.0, 0.0],
                    [ 0.5,  0.5, 0.5]])

We can use this to get a covariance matrix, whose Cholesky decomposition is of course different:

> cov = torch.matmul(L, L.t())
> chol = torch.potrf(cov, upper=False)
> chol

 1.0000  0.0000  0.0000
-2.0000  1.0000  0.0000
 0.5000 -0.5000  0.5000
[torch.FloatTensor of size 3x3]

The determinant of this covariance matrix is 0.25. We can get this from the Cholesky decomposition by

> chol.diag().prod()**2

but this is the same as

> L.diag().prod()**2

That said, I believe the current MVN code actually handles scale_tril with negative diagonal entries incorrectly (yielding a nan); I will fix this.

It seems to me one nice use case of scale_tril (as opposed to covariance_matrix) for specifying a multivariate Gaussian is that it does not require satisfying the sometimes-annoying PSD constraint, so it is easier to use the scale_tril parameterization directly on the output of some neural network layer. Requiring the diagonal to be positive makes that less natural.


fritzo commented Jan 24, 2018

it is easier to use the scale_tril parameterization directly on the output of some neural network layer.

In my very limited experience, it is important to ensure positive definiteness rather than merely semidefiniteness (sorry if I've messed this up in constraints.cholesky_lower). This is easy to do by ensuring the diagonal entries are all strictly positive via

u = Variable(torch.Tensor(4, 4).normal_(), requires_grad=True)  # optimize this
scale_tril = u.tril(-1) + u.diag().exp().diag()

If you merely instead define

scale_tril = u.tril()

then optimization will often pass through a hyperplane of singular matrices, i.e. where one of the scale_tril diagonal entries is zero. This .exp() trick is the same approach you'd take when learning a univariate Gaussian from a neural network, setting std = network_output.exp() rather than network_output.abs().

I'm happy to add constraints.lower_triangular back in (EDIT: done!), but I think in Pyro we'll favor the optimization-safe constraints.cholesky_lower (we can just wrap the PyTorch version to use a stricter constraint).


tbrx commented Jan 24, 2018

I'm not sure we should actually change it back! Just wanted to discuss. I actually agree with you that the PSD vs PD bit is probably more crucial. In that case, I guess I want to confirm that constraints.cholesky_lower would enforce positivity of the diagonal rather than just nonnegativity.

I agree it is nice (generally) to have the scale_tril parameter be something that could plausibly be the result of a Cholesky decomposition.

And actually, your code snippet may have convinced me that this isn't a problem. The u.diag().exp().diag() bit is less ugly than I thought it would be — I'd be up for (say) including your code snippet above in documentation or examples somewhere, demonstrating example usage of scale_tril.

My one remaining concern though is what happens when we (eventually) update the MVN to support batching for covariance_matrix and scale_tril. Is there a "batched" version of Tensor.diag()?


fritzo commented Jan 24, 2018

Is there a "batched" version of Tensor.diag()?

There is no batch support yet, and CholeskyLowerTransform currently raises a NotImplementedError. However, there are easy workarounds; for example, this should work for batched matrices:

def cholesky_lower_transform(x):
    if x.dim() == 2:
        return x.tril(-1) + x.diag().exp().diag()
    else:
        n = x.size(-1)
        diag = torch.eye(n, out=x.new(n, n))
        arange = torch.arange(n, out=x.new(n))
        tril = (arange.unsqueeze(-1) > arange.unsqueeze(0)).float()
        return x * tril + x.exp() * diag

This probably suffers from NAN issues, but the general idea should work.


fritzo commented Jan 24, 2018

BTW, the u.diag().exp().diag() snippet is implemented in CholeskyLowerTransform and should soon be available via constraint registration as

u = Variable(torch.Tensor(100, 100).normal_(), requires_grad=True)
scale_tril = to_constrained(constraints.cholesky_lower)(u)

or even

scale_tril = to_constrained(dist.params['scale_tril'])(u)

I'm really looking forward to using this in Pyro 😄


tbrx commented Jan 26, 2018

I believe the primary blocker to moving upstream at this point is the constraints, and a decision on the params attribute — should we wait on this until constraints.lower_cholesky is merged in as part of pytorch#4771? Or should I add implementations of _LowerCholesky (and maybe _RealVector) and assume we can merge later?

Alternatively, we could wait for pytorch#4771 and then include both this and BivariateNormal in the same pull request. It would be nice to have tests for this which reference BivariateNormal, and BivariateNormal will also need the same additional constraints.
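In case it helps, a rough, non-batched sketch of what a _LowerCholesky constraint could check (pytorch#4771 may well implement this differently): lower-triangular structure plus a strictly positive diagonal.

from torch.distributions import constraints

class _LowerCholesky(constraints.Constraint):
    def check(self, value):
        lower_triangular = (value.tril() == value).all()
        positive_diagonal = (value.diag() > 0).all()
        return lower_triangular & positive_diagonal

lower_cholesky = _LowerCholesky()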


fritzo commented Jan 26, 2018

I'd recommend adding an implementation of _LowerCholesky and maybe _RealVector in this PR and mentioning in the PR description that this is waiting on pytorch#4771. Ideally we could send PRs serially, but I'm more worried about missing the feature freeze deadline 😄


tbrx commented Jan 27, 2018

Is there anything I'm missing here (particularly in terms of test coverage…)? Otherwise, I'd be up for sending this PR upstream.

@fritzo left a comment

Looks ready to send upstream after one minor doc fix.

Re: testing, I think the strongest tests will be provided once we have a "by hand" bivariate normal distribution. We'll also be using this in Pyro right away; this should give us a little time to look for bugs and weird behavior before the PyTorch release.

:members:

:hidden:`MultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~

nit: underline is too short and will break docs. You can run make -C docs html and open docs/build/html/index.html to check docs.


tbrx commented Jan 29, 2018

After merging in the latest master, I'm getting a test failure for the Pareto distribution, on test_entropy_monte_carlo, on example 2/3. Is this expected?

Visually the results look "okay" for most entries — the max error reported is 0.349, which is on a value of 58.xxx.


fritzo commented Jan 29, 2018

@tbrx That failure is not expected. Can you make sure you've rebuilt with python setup.py build develop?


tbrx commented Jan 29, 2018

So, it seems that the monte carlo test for Pareto entropy is just very sensitive…

If I change the ordering of EXAMPLES so that MultivariateNormal comes after Pareto, then the entropy test for Pareto runs using the same random number state as it currently uses on master, in which case the test passes just fine.

@fritzo mentioned this pull request Jan 30, 2018

tbrx commented Jan 30, 2018

In working on the BivariateNormal (#99) I started writing helpers for working with torch linear algebra functions, and realized that it would be hardly more work to port and implement these here. So, I updated this to support actual batching on covariance matrices and the scale_tril parameter for MultivariateNormal (#1) as well.

Would appreciate feedback, particularly on whether I handled the "batch-friendly" matrix constraints correctly, and whether I am missing anything with the linear algebra helpers I added to torch/distributions/multivariate_normal.py. I added additional tests (which pass), but may have missed something.

Obviously the current implementation is not ideal, speed-wise:

  • We don't have a batched version of torch.potrf, so this resorts to a python list comprehension;
  • We don't have a batched version of either torch.potrs or torch.trtrs, and furthermore, neither of those currently has a backward method implemented. The workaround at the moment actually explicitly calls torch.inverse (see the sketch below). If either torch.potrs or torch.trtrs receives a .backward implementation then we can plug it in and have something more efficient (but still with python list comprehensions).
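To illustrate the interim solve workaround mentioned above (names and shapes here are hypothetical; the PR's helpers may differ), solving L x = b per batch element by explicit inversion looks something like:

def _batch_inverse_solve(bmat, bvec):
    flat_mat = bmat.contiguous().view(-1, bmat.size(-2), bmat.size(-1))
    flat_vec = bvec.contiguous().view(-1, bvec.size(-1))
    solved = [torch.matmul(torch.inverse(m), v) for m, v in zip(flat_mat, flat_vec)]
    return torch.stack(solved).view(bvec.size())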

@fritzo left a comment

The helpers look reasonable, and I like that they abstract out the mess and make MultivariateNormal methods more readable.

I'd love to have this in master soon so we can "kick its tires" and get any fixes into PyTorch 0.4 release. E.g. it would help to have other multivariate distributions for testing batch shapes of Transforms.



def _batch_mahalanobis(L, x):
    """

nit: Use r""" rather than """ to open docstrings that contain backslashes

    dims = torch.arange(n, out=bmat.new(n)).long()
    if isinstance(dims, Variable):
        dims = dims.data  # TODO: why can't I index with a Variable?
    return bmat[..., dims, dims]

Nice!

        return -0.5*(M + self.loc.size(-1)*math.log(2*math.pi)) - log_det

    def entropy(self):
        log_det = _batch_diag(self.scale_tril).abs().log().sum(-1)

Hmm, shouldn't this already have the correct shape? Why do you need the H.expand(self._batch_shape) below?
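(For reference, the quantity being computed here is the standard closed-form Gaussian differential entropy; with log_det as defined above, i.e. the sum of the log-diagonal of scale_tril, and n the event dimension, it is:

> H = 0.5 * n * (1.0 + math.log(2 * math.pi)) + log_det

which should already carry the batch shape of log_det.)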

raise ValueError("Batch shapes do not match: matrix {}, vector {}".format(bmat.shape, bvec.shape))
bvec = bvec.unsqueeze(-1)

# using `torch.bmm` is surprisingly clunky... only works when `.dim() == 3`

Before you send upstream, consider replacing with something more diplomatic 😉

e.g. "conform to torch.bmm which requires .dim() == 3"


tbrx commented Jan 30, 2018

Great, thanks @fritzo! I'll (finally!) make a new pull request upstream.
