Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
Expand Down Expand Up @@ -1087,7 +1087,7 @@ class InputPerturbation(InputTransform, Module):

def __init__(
self,
perturbation_set: Tensor,
perturbation_set: Union[Tensor, Callable[[Tensor], Tensor]],
bounds: Optional[Tensor] = None,
multiplicative: bool = False,
transform_on_train: bool = False,
Expand All @@ -1098,7 +1098,9 @@ def __init__(

Args:
perturbation_set: An `n_p x d`-dim tensor denoting the perturbations
to be added to the inputs.
to be added to the inputs. Alternatively, this can be a callable that
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the type of the arg then to be Union[Tensor, Callable[[Tensor], Tensor]?

returns `batch x n_p x d`-dim tensor of perturbations for input of
shape `batch x d`. This is useful for heteroscedastic perturbations.
bounds: A `2 x d`-dim tensor of lower and upper bounds for each
column of the input. If given, the perturbed inputs will be
clamped to these bounds.
Expand All @@ -1113,11 +1115,17 @@ def __init__(
transform when called from within a `fantasize` call. Default: False.
"""
super().__init__()
if perturbation_set.dim() != 2:
raise ValueError("`perturbation_set` must be an `n_p x d`-dim tensor!")
self.register_buffer("perturbation_set", perturbation_set)
if isinstance(perturbation_set, Tensor):
if perturbation_set.dim() != 2:
raise ValueError("`perturbation_set` must be an `n_p x d`-dim tensor!")
self.register_buffer("perturbation_set", perturbation_set)
else:
self.perturbation_set = perturbation_set
if bounds is not None:
if bounds.shape[-1] != perturbation_set.shape[-1]:
if (
isinstance(perturbation_set, Tensor)
and bounds.shape[-1] != perturbation_set.shape[-1]
):
raise ValueError(
"`bounds` must have the same number of columns (last dimension) as "
f"the `perturbation_set`! Got {bounds.shape[-1]} and "
Expand Down Expand Up @@ -1150,12 +1158,14 @@ def transform(self, X: Tensor) -> Tensor:
Returns:
A `batch_shape x (q * n_p) x d`-dim tensor of perturbed inputs.
"""
if isinstance(self.perturbation_set, Tensor):
perturbations = self.perturbation_set
else:
perturbations = self.perturbation_set(X)
expanded_X = X.unsqueeze(dim=-2).expand(
*X.shape[:-1], self.perturbation_set.shape[0], -1
)
expanded_perturbations = self.perturbation_set.expand(
*expanded_X.shape[:-1], -1
*X.shape[:-1], perturbations.shape[-2], -1
)
expanded_perturbations = perturbations.expand(*expanded_X.shape[:-1], -1)
if self.multiplicative:
perturbed_inputs = expanded_X * expanded_perturbations
else:
Expand Down
20 changes: 20 additions & 0 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from botorch.models.utils import fantasize
from botorch.utils.testing import BotorchTestCase
from gpytorch.priors import LogNormalPrior
from torch import Tensor
from torch.distributions import Kumaraswamy
from torch.nn import Module

Expand Down Expand Up @@ -1004,3 +1005,22 @@ def test_input_perturbation(self):
dtype=dtype,
)
self.assertTrue(torch.allclose(transformed, expected))

# heteroscedastic
def perturbation_generator(X: Tensor) -> Tensor:
return torch.stack([X * 0.1, X * 0.2], dim=-2)

transform = InputPerturbation(
perturbation_set=perturbation_generator
).eval()
transformed = transform(X)
expected = torch.stack(
[
X[..., 0, :] * 1.1,
X[..., 0, :] * 1.2,
X[..., 1, :] * 1.1,
X[..., 1, :] * 1.2,
],
dim=-2,
)
self.assertTrue(torch.allclose(transformed, expected))