WIP implementation of multivariate normal distribution #52
…ed covariance matrices at all.
fritzo left a comment
Nice! Could you also add a section to docs/source/distributions.rst, and maybe run cd docs; make html to ensure the docs still build (I've caught my own typos this way).
You can also take a look at the recent OneHotCategorical tests, since that is also a "multivariate" distribution with nontrivial event_shape.
test/test_distributions.py (Outdated)

from torch.autograd import Variable, gradcheck
from torch.distributions import (Bernoulli, Beta, Categorical, Dirichlet,
-                                Exponential, Gamma, Laplace, Normal)
+                                Exponential, Gamma, Laplace, Normal, MultivariateNormal)
Could you also add some example parameters in EXAMPLES below, ideally one that specifies cov and another that specifies scale_tril?
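For illustration, entries along these lines could be added (a sketch assuming the Example(Dist, params) namedtuple structure used by EXAMPLES in test_distributions.py, and using the loc / covariance_matrix / scale_tril names settled on later in this thread):

from collections import namedtuple

import torch
from torch.autograd import Variable
from torch.distributions import MultivariateNormal

# mirrors the structure of EXAMPLES in test_distributions.py (an assumption)
Example = namedtuple('Example', ['Dist', 'params'])

EXAMPLES_MVN = [
    Example(MultivariateNormal, [
        {'loc': Variable(torch.randn(5, 3), requires_grad=True),
         'covariance_matrix': Variable(torch.eye(3), requires_grad=True)},
        {'loc': Variable(torch.randn(5, 3), requires_grad=True),
         'scale_tril': Variable(torch.eye(3), requires_grad=True)},
    ]),
]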
def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    eps = self.mean.new(*shape).normal_()
    return self.mean + torch.matmul(eps, self.scale_tril.t())
Is it possible to support batched .rsample() by using torch.bmm() here instead of torch.matmul()? I'm not sure what's blocking batched covariance.
That might work for distributions specified via scale_tril as opposed to covariance_matrix. The primary blocker is a batched torch.potrf. We also need a batched solver (batched torch.gesv or otherwise) to compute the log probability.
I think we can maybe do this by using torch.btrifact and torch.btrisolve instead of potrf and gesv. Haven't looked into it yet.
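For context, a small illustration (not code from this PR) of the difference: torch.matmul broadcasts leading batch dimensions, whereas torch.bmm requires exactly 3-D inputs, so batches would have to be flattened by hand.

import torch

eps = torch.randn(2, 7, 5, 3)              # sample_shape + batch_shape + event_shape
scale_tril = torch.randn(3, 3).tril()      # a single (unbatched) scale factor

# torch.matmul broadcasts the (3, 3) matrix over all leading dimensions of eps
out = torch.matmul(eps, scale_tril.t())    # shape (2, 7, 5, 3)

# torch.bmm only accepts 3-D inputs, so the batch dimensions must be flattened first
flat = eps.view(-1, 1, 3)
out_bmm = torch.bmm(flat, scale_tril.t().expand(flat.size(0), 3, 3)).view(eps.size())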
Unfortunately, I didn't realize that torch.btrifact doesn't actually support .backward() calls.
Args:
    mean (Tensor or Variable): mean of the distribution
    cov (Tensor or Variable): covariance matrix (sigma positive-definite).
nit: It would be nice to stay maximally compatible with Tensorflow.distributions and name this covariance_matrix.
Yeah, I just left this matching the scipy mean and cov because they were so much shorter. Happy to change to loc and covariance_matrix if that is what we've settled on.
Yeah I'd like to keep the interfaces similar if possible, but I defer to your judgement here.
@tbrx How's this going? I might have time this weekend to try to add some bits of our Pyro implementation into this branch. I think it's fine to provide a batching-complete interface even if we might need to do some python iteration under the hood for now, until pytorch#4612 merges.
I haven't touched it since the last push — was waiting to see if a batched torch.potrf materialized. At the moment this actually works fine, with the caveat that a batch shape on the covariance_matrix isn't supported. Adding a version of this which handles batches by using python loops or list comprehensions shouldn't be too difficult…
Oh great, if it already works could we merge it now, then add full batched support in a follow-up PR? It would be nice to help motivate batched linear algebra work in PyTorch by claiming that "if xxx operation were batched then torch.distributions.MultivariateNormal would get batched covariance support for free".
Okay — I think for that, all we need to do is update / expand the tests. If there are any other updates that happened to the Pyro version it would be nice to merge them in too. Maybe it is worth implementing a slow version with batched covariance matrices first just to get the API correct, though. If the batch size is reasonably small it shouldn't be too slow.
@tbrx It would make @neerajprad's and my job easier if you could merge this PR soon, simply adding tests and pushing further enhancements to follow-up PRs. The Pyro team has already migrated to PyTorch distributions and we're working around the lack of MultivariateNormal.
That makes sense — the lack of a batch dimension on the covariance_matrix doesn't cause issues for you in Pyro if I understand correctly? I can update this PR and add the remaining tests Monday morning my time.
What sort of constraint should we use here for the covariance_matrix and scale_tril parameters?
Yeah, I've been thinking about that. I think we should introduce new constraints for these parameters (e.g. constraints.cholesky_lower for scale_tril).
Does that seem reasonable? They're simply symbolic placeholders, but we'll use them to register transforms to and from unconstrained space.
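A minimal sketch of the "symbolic placeholder" idea (class and instance names here are assumptions, not the final torch.distributions.constraints API):

class Constraint(object):
    """Abstract tag describing the set of valid values for a parameter."""
    def check(self, value):
        raise NotImplementedError

class _PositiveDefinite(Constraint):
    """Placeholder: symmetric positive-definite matrices (for covariance_matrix)."""

class _CholeskyLower(Constraint):
    """Placeholder: lower-triangular matrices with positive diagonal (for scale_tril)."""

positive_definite = _PositiveDefinite()
cholesky_lower = _CholeskyLower()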
BTW I've added an issue #99 for implementing a BivariateNormal distribution.
fritzo left a comment
Looks good so far!
self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril).sample((2,7)).size(), (2, 7, 5, 3))

# check gradients
nice tests!
    Only one of `covariance_matrix` or `scale_tril` can be specified.
    """
    params = {'loc': constraints.real,
@jwvdm noted that we could generically retrieve params for a distribution if we specified a canonical set of parameters. I've been trying to do this by putting only a single canonical parameterization in the .params dict (e.g. either loc,covariance_matrix or loc,scale_tril but not all three).
But I like what you've done here by adding them all. Maybe we should do that for all distributions and specify canonical_params or something in another field, or just let higher level libraries like Pyro or ProbTorch do that. WDYT?
def __init__(self, loc, covariance_matrix=None, scale_tril=None):
    batch_shape, event_shape = loc.shape[:-1], loc.shape[-1:]
    if covariance_matrix is not None and scale_tril is not None:
nit: You could simplify via

if (covariance_matrix is None) == (scale_tril is None):
    raise ValueError(...)

        raise ValueError("Either covariance matrix or scale_tril may be specified, not both.")
    if covariance_matrix is None and scale_tril is None:
        raise ValueError("One of either covariance matrix or scale_tril must be specified")
    if scale_tril is None:
Neeraj made this cool decorator called @lazy_property that could allow you to create scale_tril only if it does not exist, which would avoid unnecessary work in some cases. You could use it as follows:
class MultivariateNormal(Distribution):
    def __init__(...):
        ...
        if scale_tril is not None:
            self.scale_tril = scale_tril
            # leave .covariance_matrix unset
        else:
            self.covariance_matrix = covariance_matrix
            # leave .scale_tril unset
        ...

    @lazy_property
    def scale_tril(self):
        return torch.potrf(self.covariance_matrix, upper=False)

    @lazy_property
    def covariance_matrix(self):
        return torch.mm(self.scale_tril, self.scale_tril.t())
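For reference, a minimal sketch of how such a lazy_property decorator can be implemented (a reconstruction, not necessarily the exact decorator referenced above):

import functools

class lazy_property(object):
    """Evaluate the wrapped method once on first access, then cache the result
    as an ordinary instance attribute under the same name."""
    def __init__(self, wrapped):
        self.wrapped = wrapped
        functools.update_wrapper(self, wrapped)

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        value = self.wrapped(instance)
        setattr(instance, self.wrapped.__name__, value)
        return value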
@tbrx In pytorch#4771 I've replaced |
|
@fritzo Actually… I like leaving it just as The main (potential) problem would be in computing the determinant of the covariance matrix. But that's actually fine. Here's an example lower-triangular matrix with a negative entry on the diagonal: We can use this to get a covariance matrix, whose Cholesky decomposition is of course different: The determinant of this covariance matrix is 0.25. We can get this from the Cholesky decomposition by but this is the same as That said, I believe the current MVN code actually handles It seems to me one nice use case of |
In my very limited experience, it is important to ensure positive definiteness rather than merely semidefiniteness (sorry if I've messed this up somewhere in the following):

u = Variable(torch.Tensor(4, 4).normal_(), requires_grad=True)  # optimize this
scale_tril = u.tril(-1) + u.diag().exp().diag()

If you instead merely define

scale_tril = u.tril()

then optimization will often pass through a hyperplane of singular matrices, i.e. where one of the diagonal entries is zero. I'm happy to add …
I'm not sure we should actually change it back! Just wanted to discuss. I actually agree with you that the PSD vs PD bit is probably more crucial. In that case I want to confirm what constraint we settle on.

I guess that I agree it is nice (generally) to have the diagonal kept positive. And actually, your code snippet may have convinced me that this isn't a problem.

My one remaining concern though is what happens when we (eventually) update the MVN to support batching for scale_tril.
There is no batch support yet, and the transform would need to be generalized to handle batched inputs, something like:

def cholesky_lower_transform(x):
    if x.dim() == 2:
        return x.tril(-1) + x.diag().exp().diag()
    else:
        n = x.size(-1)
        diag = torch.eye(n, out=x.new(n, n))
        arange = torch.arange(n, out=x.new(n))
        tril = (arange.unsqueeze(-1) > arange.unsqueeze(0)).float()
        return x * tril + x.exp() * diag

This probably suffers from NaN issues, but the general idea should work.
|
BTW The u = Variable(torch.Tensor(100, 100).normal_(), requires_grad=True)
scale_tril = to_constrained(constraints.cholesky_lower)(u)or even scale_tril = to_constrained(dist.params['scale_tril'])(u)I'm really looking forward to using this in Pyro 😄 |
…ay computation until after init
I believe the primary blocker to moving upstream at this point is the constraints, and a decision on the scale_tril parameterization. Alternatively, we could wait for pytorch#4771 and then include both this and the constraints work together.
I'd recommend adding an implementation of …
Is there anything I'm missing here (particularly in terms of test coverage…)? Otherwise, I'd be up for sending this PR upstream.
fritzo left a comment
Looks ready to send upstream after one minor doc fix.
Re: testing, I think the strongest tests will be provided once we have a "by hand" bivariate normal distribution. We'll also be using this in Pyro right away; this should give us a little time to look for bugs and weird behavior before the PyTorch release.
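For reference, a "by hand" bivariate normal log-density would just encode the standard closed form (the function name and signature here are illustrative):

import math

def bivariate_normal_log_prob(x, y, mu_x, mu_y, sigma_x, sigma_y, rho):
    # standard closed-form log-density of a bivariate normal with correlation rho
    zx = (x - mu_x) / sigma_x
    zy = (y - mu_y) / sigma_y
    q = (zx ** 2 - 2 * rho * zx * zy + zy ** 2) / (1 - rho ** 2)
    return -0.5 * q - math.log(2 * math.pi * sigma_x * sigma_y * math.sqrt(1 - rho ** 2))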
    :members:

:hidden:`MultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~
nit: underline is too short and will break docs. You can run make -C docs html and open docs/build/html/index.html to check docs.
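That is, the underline needs to be at least as long as the heading text, e.g.:

:hidden:`MultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~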
After merging in the latest master, I'm seeing a failure in the Monte Carlo entropy check for Pareto. Visually the results look "okay" for most entries — the max error reported is 0.349, which is on a value of 58.xxx.
@tbrx That failure is not expected. Can you make sure you've rebuilt with …
So, it seems that the Monte Carlo test for Pareto entropy is just very sensitive… If I change the ordering of …
…ers for multivariate normal.
In working on the BivariateNormal #99 I started writing helpers for working with torch linear algebra functions, and realized that it would actually be hardly more work to port and implement these here. So, I updated this to support actual batching on covariance matrices and the scale_tril argument.

Would appreciate feedback, particularly on whether I handled the "batch-friendly" matrix constraints correctly, and whether I am missing anything with the linear algebra helpers I added. Obviously the current implementation is not ideal, speed-wise.
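As a point of reference for what the Mahalanobis helper has to compute, here is an unbatched sketch using that era's torch.trtrs (not the PR's actual helper): with lower Cholesky factor L of the covariance and residual x, solve L z = x and return the squared norm of z.

import torch

def mahalanobis_unbatched(L, x):
    # L: (n, n) lower-triangular, x: (n,)
    # z solves L z = x, so (z ** 2).sum() equals x^T (L L^T)^{-1} x
    z, _ = torch.trtrs(x.unsqueeze(-1), L, upper=False)
    return (z ** 2).sum()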
fritzo left a comment
The helpers look reasonable, and I like that they abstract out the mess and make MultivariateNormal methods more readable.
I'd love to have this in master soon so we can "kick its tires" and get any fixes into PyTorch 0.4 release. E.g. it would help to have other multivariate distributions for testing batch shapes of Transforms.
def _batch_mahalanobis(L, x):
    """
nit: Use r""" rather than """ to open docstrings that contain backslashes
dims = torch.arange(n, out=bmat.new(n)).long()
if isinstance(dims, Variable):
    dims = dims.data  # TODO: why can't I index with a Variable?
return bmat[...,dims,dims]
Nice!
    return -0.5*(M + self.loc.size(-1)*math.log(2*math.pi)) - log_det

def entropy(self):
    log_det = _batch_diag(self.scale_tril).abs().log().sum(-1)
Hmm, shouldn't this already have the correct shape? Why do you need the H.expand(self._batch_shape) below?
    raise ValueError("Batch shapes do not match: matrix {}, vector {}".format(bmat.shape, bvec.shape))
bvec = bvec.unsqueeze(-1)

# using `torch.bmm` is surprisingly clunky... only works when `.dim() == 3`
Before you send upstream, consider replacing with something more diplomatic 😉
conform to torch.bmm which requires .dim() == 3
Great, thanks @fritzo! I'll (finally!) make a new pull request upstream.
For issue #1. @fritzo

- mean argument, plus (either) cov or scale_tril
- torch.gesv for computing log_prob; if requires_grad=False then we could do a (cheaper) torch.potrs… probably worth using a solver-helper here like @dwd31415 has in the Pyro PR.

Argument naming convention at the moment is: mean and cov to match scipy.stats.multivariate_normal, and scale_tril to match the Pyro PR.

Test coverage is spotty at the moment (in particular I had some issue with the _gradcheck_log_prob helper), but shapes seem okay and logprob values match scipy.

One question is when we should compute the Cholesky decomposition if passed a cov argument instead of scale_tril. I opted to call it up front in the constructor -- we're ultimately going to need it no matter what, either for sampling, or for computing the log determinant in the log_prob or entropy.
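To make the gesv / potrs trade-off concrete, a rough sketch under the old torch.potrf / torch.gesv / torch.potrs APIs mentioned above (variable names are illustrative):

import torch

n = 3
cov = 2.0 * torch.eye(n)                    # toy covariance matrix
scale_tril = torch.potrf(cov, upper=False)  # lower Cholesky factor
delta = torch.randn(n, 1)                   # (value - mean) as a column vector

# general solve: cov @ z = delta
z_gesv, _ = torch.gesv(delta, cov)

# cheaper solve that reuses the Cholesky factor: (L L^T) @ z = delta
z_potrs = torch.potrs(delta, scale_tril, upper=False)

mahalanobis = (delta * z_potrs).sum()       # delta^T cov^{-1} delta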