-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Refactor and simplify MultivariateNormal distribution #693
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
jpchen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
failing tests though: |
neerajprad
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for moving this into torch.distributions!
|
@fritzo My GPRegression tutorial runs fine with this pull request. However, for a general case (when I use other likelihoods such as Bernoulli for GPClassification), Here is a simple script to replicate the problem Of course, the problem is more related to "how to deal with such RuntimeError", rather than this implementation of MultivariateNormal. |
|
I think that it is better to raise an issue to discuss, rather than comment here. |
This refactors and simplifies Pyro's
MultivariateNormaldistribution into a PyTorch-styleTorchMultivariateNormaldistribution plus a wrapper. This makes it easy to swap in a PyTorchMultivariateNormalas soon as it is available (as soon as probtorch/pytorch#52 is merged).This PR is needed to support @fehiepsi 's Gaussian Process tutorial #650 as Pyro migrates to PyTorch distributions.