Skip to content

Conversation

@fritzo
Copy link
Member

@fritzo fritzo commented Jan 19, 2018

This refactors and simplifies Pyro's MultivariateNormal distribution into a PyTorch-style TorchMultivariateNormal distribution plus a wrapper. This makes it easy to swap in a PyTorch MultivariateNormal as soon as it is available (as soon as probtorch/pytorch#52 is merged).

This PR is needed to support @fehiepsi 's Gaussian Process tutorial #650 as Pyro migrates to PyTorch distributions.

@fritzo fritzo added the WIP label Jan 19, 2018
@fritzo fritzo requested a review from neerajprad January 19, 2018 00:49
jpchen
jpchen previously approved these changes Jan 19, 2018
Copy link
Member

@jpchen jpchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jpchen
Copy link
Member

jpchen commented Jan 19, 2018

failing tests though:

E           RuntimeError: Lapack Error in potrf : the leading minor of order 2 is not positive definite at /remote/pytorch/aten/src/TH/generic/THTensorLapack.c:617

Copy link
Member

@neerajprad neerajprad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for moving this into torch.distributions!

@neerajprad neerajprad merged commit 9844a14 into dev Jan 23, 2018
@fehiepsi
Copy link
Member

fehiepsi commented Jan 23, 2018

@fritzo My GPRegression tutorial runs fine with this pull request. However, for a general case (when I use other likelihoods such as Bernoulli for GPClassification), torch.potrs will throw errors RuntimeError: Lapack Error in potrf : the leading minor of order ... is not positive definite at ....

Here is a simple script to replicate the problem

import torch
from torch.autograd import Variable

import pyro
import pyro.distributions as dist

K = Variable(torch.rand((1000, 1000))).double()
K = K @ K.t()

y = Variable(torch.zeros(1000)).double()

z = pyro.sample("z", dist.MultivariateNormal(y, K))

Of course, the problem is more related to "how to deal with such RuntimeError", rather than this implementation of MultivariateNormal.

@fehiepsi
Copy link
Member

fehiepsi commented Jan 23, 2018

I think that it is better to raise an issue to discuss, rather than comment here.

#696

@fritzo fritzo deleted the torch-mvn branch February 3, 2018 00:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants