Merged
Changes from 22 commits
Commits (27 total)
a0c528e
init gp
fehiepsi Dec 14, 2017
a06908e
Merge remote-tracking branch 'upstream/dev' into add-gp
fehiepsi Dec 16, 2017
5946af8
add skeleton
fehiepsi Dec 16, 2017
eaf91fc
temp save
fehiepsi Dec 19, 2017
8d8d272
add gp tutorial
fehiepsi Dec 22, 2017
94aa221
add more text
fehiepsi Dec 26, 2017
df30ef1
show the fail case of inference, and the importance of constraint
fehiepsi Dec 27, 2017
cdfe529
move skeleton to contrib
fehiepsi Jan 14, 2018
c7c22cc
remove supporting for likelihoods and mean_functions
fehiepsi Jan 14, 2018
14e42c1
add rbf, gpr
fehiepsi Jan 15, 2018
08e8d2f
fix bugs and and tests
fehiepsi Jan 15, 2018
7573ab4
Merge remote-tracking branch 'upstream/dev' into add-gp
fehiepsi Jan 15, 2018
ddd91ec
lint
fehiepsi Jan 15, 2018
6b7e38b
Merge branch 'add-gp' into add-gp-tutorial
fehiepsi Jan 15, 2018
000aa1f
update tutorials
fehiepsi Jan 15, 2018
8619c4c
fix bugs
fehiepsi Jan 15, 2018
d02828d
review
fehiepsi Jan 16, 2018
bf13667
Merge branch 'add-gp' into add-gp-tutorial
fehiepsi Jan 16, 2018
d676a0b
use new instead of type_as
fehiepsi Jan 16, 2018
241391f
add params doc for gpr, remove unnecessary K
fehiepsi Jan 19, 2018
cca7069
remove .K(...)
fehiepsi Jan 19, 2018
6efd5bf
remove .K method in test, fix a bug in setting GPR noise
fehiepsi Jan 19, 2018
5f454e5
add input_dim parameter to kernel, replace gesv by potrf
fehiepsi Jan 20, 2018
dbd79c4
make lint
fehiepsi Jan 20, 2018
758bbf5
rerun tutorial to add input_dim to kernel
fehiepsi Jan 20, 2018
e3bcfe3
fix potrs not support grad
fehiepsi Jan 22, 2018
cdcd7bb
fix typo self.y
fehiepsi Jan 22, 2018
5 changes: 5 additions & 0 deletions docs/source/contrib.gp.rst
@@ -0,0 +1,5 @@
.. automodule:: pyro.contrib.gp
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
1 change: 1 addition & 0 deletions docs/source/contrib.rst
@@ -8,3 +8,4 @@ Contributed Code
:maxdepth: 2

contrib.named
contrib.gp
1 change: 1 addition & 0 deletions pyro/contrib/__init__.py
@@ -0,0 +1 @@
from __future__ import absolute_import, division, print_function
1 change: 1 addition & 0 deletions pyro/contrib/gp/__init__.py
@@ -0,0 +1 @@
from __future__ import absolute_import, division, print_function
5 changes: 5 additions & 0 deletions pyro/contrib/gp/kernels/__init__.py
@@ -0,0 +1,5 @@
from __future__ import absolute_import, division, print_function

from .rbf import RBF

# flake8: noqa
47 changes: 47 additions & 0 deletions pyro/contrib/gp/kernels/kernel.py
@@ -0,0 +1,47 @@
from __future__ import absolute_import, division, print_function

import torch.nn as nn


class Kernel(nn.Module):
"""
Base class for kernels used in Gaussian Processes.

Every subclass should implement the forward pass, which takes inputs X, Z
and returns their covariance matrix.
"""

def __init__(self, active_dims=None, name=None):
super(Kernel, self).__init__()
self.active_dims = active_dims
self.name = name

def forward(self, X, Z=None):
"""
Calculate the covariance matrix of the inputs over the active dimensions.

:param torch.autograd.Variable X: A 2D tensor of size `N x input_dim`.
:param torch.autograd.Variable Z: An optional 2D tensor of size `N x input_dim`; defaults to `X`.
:return: Covariance matrix of X and Z with size `N x N`.
:rtype: torch.autograd.Variable
"""
raise NotImplementedError

def _slice_X(self, X):
"""
Slice X according to `self.active_dims`. If X is 1-dimensional, return
a 2D tensor of size `N x 1`.

:param torch.autograd.Variable X: A 1D or 2D tensor.
:return: A 2D slice of X.
:rtype: torch.autograd.Variable
"""
if X.dim() == 2:
active_dims = self.active_dims
if active_dims is None:
active_dims = slice(X.size(1))
return X[:, active_dims]
elif X.dim() == 1:
return X.unsqueeze(1)
else:
raise ValueError("Input X must be either 1 or 2 dimensional.")
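A small usage sketch of the `active_dims` slicing above, using the RBF kernel added below. It assumes `active_dims` is given as a Python slice; the exact accepted types are not pinned down in this PR, so treat this as an illustration only:

import torch
from torch.autograd import Variable

from pyro.contrib.gp.kernels import RBF

# covariance is computed only on the first two input columns
kernel = RBF(active_dims=slice(0, 2))
X = Variable(torch.randn(5, 3))
K = kernel(X)  # internally uses X[:, 0:2]
assert K.size() == (5, 5)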
31 changes: 31 additions & 0 deletions pyro/contrib/gp/kernels/rbf.py
@@ -0,0 +1,31 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.nn import Parameter

from .kernel import Kernel


class RBF(Kernel):
"""
Implementation of the Radial Basis Function (RBF) kernel.
"""

def __init__(self, variance=torch.ones(1), lengthscale=torch.ones(1), active_dims=None, name="RBF"):
super(RBF, self).__init__(active_dims=active_dims, name=name)
self.variance = Parameter(variance)
self.lengthscale = Parameter(lengthscale)

Review comment (ysaatchi):
Need to add a description of how lengthscale works. Typically you have one independent lengthscale per dimension, but you are assuming that lengthscales across all dimensions are the same (?)

fehiepsi (Member Author) replied:
@ysaatchi You are right (originally, I wrote this module for 1-dimensional data). We need to refactor it a bit. I will solve it by introducing an input_dim parameter (similar to GPy or GPflow) so we can set the correct shape for the initial lengthscale.


def forward(self, X, Z=None):
if Z is None:
Z = X
X = self._slice_X(X)
Z = self._slice_X(Z)
if X.size(1) != Z.size(1):
raise ValueError("Inputs must have the same number of features.")

X2 = (X ** 2).sum(1, keepdim=True)
Z2 = (Z ** 2).sum(1, keepdim=True)
XZ = X.matmul(Z.t())
d2 = X2 - 2 * XZ + Z2.t()
return self.variance * torch.exp(-0.5 * d2 / self.lengthscale**2)
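A hedged sketch of the per-dimension (ARD) lengthscale discussed in the thread above. The `input_dim` parameter and per-dimension scaling follow the reviewer's suggestion and GPy/GPflow conventions; this is not the code merged in this PR:

import torch
from torch.nn import Parameter

from pyro.contrib.gp.kernels.kernel import Kernel


class ARDRBF(Kernel):
    """RBF kernel with one lengthscale per input dimension (ARD sketch)."""

    def __init__(self, input_dim, variance=None, lengthscale=None,
                 active_dims=None, name="ARDRBF"):
        super(ARDRBF, self).__init__(active_dims=active_dims, name=name)
        variance = torch.ones(1) if variance is None else variance
        # one lengthscale per input dimension
        lengthscale = torch.ones(input_dim) if lengthscale is None else lengthscale
        self.variance = Parameter(variance)
        self.lengthscale = Parameter(lengthscale)

    def forward(self, X, Z=None):
        if Z is None:
            Z = X
        X = self._slice_X(X)
        Z = self._slice_X(Z)
        # rescale each dimension by its own lengthscale, then compute
        # pairwise squared distances
        Xs = X / self.lengthscale
        Zs = Z / self.lengthscale
        X2 = (Xs ** 2).sum(1, keepdim=True)
        Z2 = (Zs ** 2).sum(1, keepdim=True)
        d2 = X2 - 2 * Xs.matmul(Zs.t()) + Z2.t()
        return self.variance * torch.exp(-0.5 * d2)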
5 changes: 5 additions & 0 deletions pyro/contrib/gp/models/__init__.py
@@ -0,0 +1,5 @@
from __future__ import absolute_import, division, print_function

from .gpr import GPRegression

# flake8: noqa
73 changes: 73 additions & 0 deletions pyro/contrib/gp/models/gpr.py
@@ -0,0 +1,73 @@
from __future__ import absolute_import, division, print_function

from torch.autograd import Variable
import torch.nn as nn

import pyro
import pyro.distributions as dist


class GPRegression(nn.Module):
"""
Gaussian Process regression module.
Review comment (Member):
Could you add :param: descriptions for the inputs? Something like

    :param torch.autograd.Variable X: A tensor of inputs.
    :param torch.autograd.Variable y: A tensor of outputs for training.
    :param pyro.gp.kernel.Kernel kernel: A Pyro kernel object.
    :param torch.Tensor noise: An optional noise tensor.
    :param dict priors: A mapping from ??? to priors.

:param torch.autograd.Variable X: A tensor of inputs.
:param torch.autograd.Variable y: A tensor of outputs for training.
:param pyro.gp.kernels.Kernel kernel: A Pyro kernel object.
:param torch.Tensor noise: An optional noise tensor.
:param dict priors: A mapping from kernel parameter names to priors.
"""
def __init__(self, X, y, kernel, noise=None, priors=None):
super(GPRegression, self).__init__()
self.X = X
self.y = y
self.input_dim = X.size(0)
self.kernel = kernel
# TODO: define noise as a nn.Module, so we can train/set prior to it
Review comment (ysaatchi, Jan 19, 2018):
Not the best idea; you should define noise as a likelihood with its own hyperparameters and optimize it that way. In general we need to support arbitrary likelihoods for the GP, so defining them at this early stage will be very helpful.

fehiepsi (Member Author) replied:
Yes, I agree!

if noise is None:
self.noise = Variable(X.data.new([1]))

Review comment:
What does this do? Needs a comment.

fehiepsi (Member Author) replied:
Noise plays the role of a Gaussian likelihood. I intend to add Likelihood and mean function modules in a future pull request, so temporarily I let it be a constant. See #681 for the plan I have in mind. This pull request serves to simplify the original tutorial code.

else:
self.noise = Variable(noise)
self.priors = priors
if priors is None:
self.priors = {}

def model(self):
kernel_fn = pyro.random_module(self.kernel.name, self.kernel, self.priors)
kernel = kernel_fn()
K = kernel(self.X) + self.noise.repeat(self.input_dim).diag()
zero_loc = Variable(K.data.new([0]).expand(self.input_dim))
pyro.sample("f", dist.MultivariateNormal(zero_loc, K), obs=self.y)

def guide(self):

Review comment:
What is the purpose of the guide in this context? It seems like you are doing inference in forward(), so what is the point of the guide?

fehiepsi (Member Author) replied:
Originally, I wanted to put some constraints on the parameters, but that requires writing a wrapper around nn.Parameter (to support transform methods). Then I found that setting priors and using the guide for MAP inference might be a simpler idea. Do you see a better way of using the guide?

guide_priors = {}
for p in self.priors:
p_MAP_name = pyro.param_with_module_name(self.kernel.name, p) + "_MAP"
# init params by their prior means
p_MAP = pyro.param(p_MAP_name, Variable(self.priors[p].analytic_mean().data.clone(),
requires_grad=True))
guide_priors[p] = dist.Delta(p_MAP)
kernel_fn = pyro.random_module(self.kernel.name, self.kernel, guide_priors)
return kernel_fn()

def forward(self, Z):
"""
Compute the parameters of `p(y|Z) ~ N(loc, covariance_matrix)`
for new inputs Z.

:param torch.autograd.Variable Z: A 2D tensor.
:return: loc and covariance matrix of p(y|Z).
:rtype: torch.autograd.Variable and torch.autograd.Variable
"""
if Z.dim() == 2 and self.X.size(1) != Z.size(1):
raise ValueError("Train data and test data should have the same number of features.")
if Z.dim() == 1:
Z = Z.unsqueeze(1)
kernel = self.guide()
K = kernel(self.X) + self.noise.repeat(self.input_dim).diag()
K_xz = kernel(self.X, Z)
K_zx = K_xz.t()
K_zz = kernel(Z)
loc = K_zx.matmul(self.y.gesv(K)[0]).squeeze(1)
covariance_matrix = K_zz - K_zx.matmul(K_xz.gesv(K)[0])

Review comment (ysaatchi):
This is very inefficient (calling gesv twice on an N x N matrix); see the GPML book (Gaussian Processes for Machine Learning) for the correct pseudocode for doing this.

Review comment:
Out of interest, does gesv play well with autograd -- quite cool if so :)

fehiepsi (Member Author) replied:
@ysaatchi Originally, I treated noise as a hyperparameter and used Cholesky decomposition. Then, when the noise was small, the LAPACK error "the leading minor of order ... is not positive definite" annoyed me. In addition, I found torch.trtrs to be unstable (pytorch/pytorch#4296), so I used gesv instead. Of course, these problems might have come from bugs in my code at that time.

Anyway, using gesv might not be a good approach, so I will switch back to Cholesky decomposition.

P.S.: gesv supports autograd, but does not support batching yet. :)

return loc, covariance_matrix
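For reference, a minimal sketch of the Cholesky-based prediction the thread above points to (GPML, Algorithm 2.1). It is written against the current torch.linalg API (cholesky, cholesky_solve, solve_triangular) rather than the 0.3-era potrf/gesv calls in this PR, and the function and tensor names are illustrative only:

import torch


def gp_predict(K_xx, K_xz, K_zz, y, noise):
    """Posterior mean and covariance of a GP at test inputs.

    K_xx: (N, N) train kernel, K_xz: (N, M) train/test kernel,
    K_zz: (M, M) test kernel, y: (N,) targets, noise: scalar variance.
    """
    N = K_xx.size(0)
    # single Cholesky factorization of the noisy train covariance
    L = torch.linalg.cholesky(K_xx + noise * torch.eye(N))
    # alpha = (K_xx + noise*I)^{-1} y, via two triangular solves
    alpha = torch.cholesky_solve(y.unsqueeze(1), L)
    loc = K_xz.t().matmul(alpha).squeeze(1)
    # v = L^{-1} K_xz, so that v^T v = K_zx (K_xx + noise*I)^{-1} K_xz
    v = torch.linalg.solve_triangular(L, K_xz, upper=False)
    covariance_matrix = K_zz - v.t().matmul(v)
    return loc, covariance_matrix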
1 change: 1 addition & 0 deletions tests/contrib/gp/__init__.py
@@ -0,0 +1 @@
from __future__ import absolute_import, division, print_function
18 changes: 18 additions & 0 deletions tests/contrib/gp/test_kernels.py
@@ -0,0 +1,18 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.autograd import Variable

from pyro.contrib.gp.kernels import RBF
from tests.common import assert_equal


def test_forward_rbf():
kernel = RBF(variance=torch.Tensor([2]), lengthscale=torch.Tensor([2]))
X = Variable(torch.Tensor([[1, 0, 1], [2, 1, 3]]))
Z = Variable(torch.Tensor([[4, 5, 6], [3, 1, 7]]))
K = kernel(X, Z)
assert K.dim() == 2
assert K.size(0) == 2
assert K.size(1) == 2
assert_equal(K.data.sum(), 0.30531)
24 changes: 24 additions & 0 deletions tests/contrib/gp/test_models.py
@@ -0,0 +1,24 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.autograd import Variable

from pyro.contrib.gp.kernels import RBF
from pyro.contrib.gp.models import GPRegression
from tests.common import assert_equal


def test_forward_gpr():
kernel = RBF(torch.ones(1), torch.ones(1))
X = Variable(torch.Tensor([[1, 2, 3], [4, 5, 6]]))
y = Variable(torch.Tensor([0, 1]))
gpr = GPRegression(X, y, kernel, noise=torch.zeros(1))
Z = X
loc, cov = gpr(Z)
assert loc.dim() == 1
assert cov.dim() == 2
assert loc.size(0) == 2
assert cov.size(0) == 2
assert cov.size(1) == 2
assert_equal(loc.data.sum(), kernel(X).matmul(y).data.sum())
assert_equal(cov.data.abs().sum(), 0)
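To make the model/guide discussion concrete, a hedged usage sketch of MAP training of GPRegression with SVI. The toy data, the choice of prior, the noise value, and the loss="ELBO" string (pre-0.2 Pyro SVI API) are assumptions for illustration; exact signatures in the Pyro version targeted by this PR may differ:

import torch
from torch.autograd import Variable

import pyro
import pyro.distributions as dist
import pyro.optim
from pyro.infer import SVI
from pyro.contrib.gp.kernels import RBF
from pyro.contrib.gp.models import GPRegression

# Toy 1D regression data.
X = Variable(torch.linspace(-1, 1, 20).unsqueeze(1))
y = Variable(torch.sin(3 * torch.linspace(-1, 1, 20)))

kernel = RBF(variance=torch.ones(1), lengthscale=torch.ones(1))
# Keys of `priors` are kernel parameter names; the guide turns each prior
# into a Delta distribution at a learnable point, i.e. MAP estimation.
priors = {"lengthscale": dist.Normal(Variable(torch.ones(1)), Variable(torch.ones(1)))}
gpr = GPRegression(X, y, kernel, noise=torch.Tensor([0.1]), priors=priors)

# loss="ELBO" follows the pre-0.2 Pyro SVI API; later versions take Trace_ELBO().
svi = SVI(gpr.model, gpr.guide, pyro.optim.Adam({"lr": 0.01}), loss="ELBO")
for step in range(1000):
    svi.step()

# Posterior predictive parameters at new inputs.
Z = Variable(torch.linspace(-1, 1, 50).unsqueeze(1))
loc, cov = gpr(Z)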