diff --git a/evidential_deep_learning/__init__.py b/evidential_deep_learning/__init__.py
index 9c5a86c..6c2aefc 100644
--- a/evidential_deep_learning/__init__.py
+++ b/evidential_deep_learning/__init__.py
@@ -1,2 +1,65 @@
-from . import layers
-from . import losses
+# TODO: This is pretty hacky namespace manipulation but it works
+import sys
+
+self = sys.modules[__name__]
+
+default_backend = 'tf'
+
+self.torch_avail = False
+try:
+    import torch
+
+    self.torch_avail = True
+    self.backend = 'torch'
+except ImportError:
+    pass
+
+self.tf_avail = False
+try:
+    import tensorflow as tf
+
+    self.tf_avail = True
+    self.backend = 'tf'
+except ImportError:
+    pass
+
+if not (self.torch_avail or self.tf_avail):
+    raise ImportError("Must have either PyTorch or Tensorflow available")
+
+if self.torch_avail and self.tf_avail:
+    self.backend = default_backend
+
+
+def set_backend(backend):
+    """Bind ``layers`` and ``losses`` on this package to the chosen backend.
+
+    ``backend`` must be ``'tf'`` or ``'torch'``.  Raises ``ImportError`` when
+    the requested framework was not importable at package import time and
+    ``ValueError`` for any other name.
+    """
+    if backend == 'tf':
+        if not self.tf_avail:
+            raise ImportError("Cannot use backend 'tf' if tensorflow is not installed")
+        from .tf import layers as layers
+        from .tf import losses as losses
+    elif backend == 'torch':
+        if not self.torch_avail:
+            raise ImportError("Cannot use backend 'torch' if pytorch is not installed")
+        from .pytorch import layers as layers
+        from .pytorch import losses as losses
+    else:
+        raise ValueError(f"Invalid choice of backend: {backend}, options are 'tf' or 'torch'")
+    self.layers = layers
+    self.losses = losses
+    # Record the switch so that get_backend() reports what is actually bound.
+    self.backend = backend
+
+
+def get_backend():
+    """Return the name of the backend currently bound (``'tf'`` or ``'torch'``)."""
+    return self.backend
+
+
+self.get_backend = get_backend
+self.set_backend = set_backend
+self.set_backend(self.backend)
diff --git a/evidential_deep_learning/pytorch/__init__.py b/evidential_deep_learning/pytorch/__init__.py
new file mode 100644
index 0000000..1a2f7a6
--- /dev/null
+++ b/evidential_deep_learning/pytorch/__init__.py
@@ -0,0 +1 @@
+from .
import losses, layers
diff --git a/evidential_deep_learning/pytorch/layers/__init__.py b/evidential_deep_learning/pytorch/layers/__init__.py
new file mode 100644
index 0000000..cb3bb70
--- /dev/null
+++ b/evidential_deep_learning/pytorch/layers/__init__.py
@@ -0,0 +1,2 @@
+from .dense import *
+from .conv2d import *
\ No newline at end of file
diff --git a/evidential_deep_learning/pytorch/layers/conv2d.py b/evidential_deep_learning/pytorch/layers/conv2d.py
new file mode 100644
index 0000000..af7f2f1
--- /dev/null
+++ b/evidential_deep_learning/pytorch/layers/conv2d.py
@@ -0,0 +1,49 @@
+import torch
+from torch.nn import Module, Conv2d
+import torch.nn.functional as F
+
+
+# TODO: efficiently handle batch dimension
+
+
+class Conv2DNormal(Module):
+    """2-D convolution head that emits a mean and a positive scale per task."""
+
+    def __init__(self, in_channels, out_tasks, kernel_size, **kwargs):
+        super().__init__()
+        self.in_channels = in_channels
+        self.n_tasks = out_tasks
+        # One channel group for mu and one for the raw (pre-softplus) sigma.
+        self.out_channels = 2 * out_tasks
+        self.conv = Conv2d(self.in_channels, self.out_channels, kernel_size, **kwargs)
+
+    def forward(self, x):
+        """Return ``stack([mu, sigma])`` computed from the convolution of ``x``."""
+        features = self.conv(x)
+        # Channels are the third axis from the end for both 3-D inputs
+        # (previously split on dim 0) and 4-D inputs (previously dim 1).
+        mu, logsigma = torch.split(features, self.n_tasks, dim=-3)
+        sigma = F.softplus(logsigma) + 1e-6
+        return torch.stack([mu, sigma]).to(x.device)
+
+
+class Conv2DNormalGamma(Module):
+    """2-D convolution head that emits the gamma, nu, alpha and beta maps."""
+
+    def __init__(self, in_channels, out_tasks, kernel_size, **kwargs):
+        super().__init__()
+        self.in_channels = in_channels
+        # NOTE: this is the per-parameter channel count; the conv has 4x as many.
+        self.out_channels = out_tasks
+        self.conv = Conv2d(in_channels, 4 * out_tasks, kernel_size, **kwargs)
+
+    def forward(self, x):
+        features = self.conv(x)
+        # Same channel axis (-3) whether ``x`` is 3-D or 4-D.
+        gamma, lognu, logalpha, logbeta = torch.split(features, self.out_channels, dim=-3)
+
+        nu = F.softplus(lognu)
+        alpha = F.softplus(logalpha) + 1.
+        beta = F.softplus(logbeta)
+        return torch.stack([gamma, nu, alpha, beta]).to(x.device)
+
diff --git a/evidential_deep_learning/pytorch/layers/dense.py b/evidential_deep_learning/pytorch/layers/dense.py
new file mode 100644
index 0000000..d102dc3
--- /dev/null
+++ b/evidential_deep_learning/pytorch/layers/dense.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Module
+
+# TODO: Find a way to efficiently handle batch dimension
+
+
+class DenseNormal(Module):
+    """Linear head producing a mean and a strictly positive scale per task."""
+
+    def __init__(self, n_input, n_out_tasks=1):
+        super(DenseNormal, self).__init__()
+        self.n_in = n_input
+        self.n_out = 2 * n_out_tasks
+        self.n_tasks = n_out_tasks
+        self.l1 = nn.Linear(self.n_in, self.n_out)
+
+    def forward(self, x):
+        """Return ``stack([mu, sigma])``; each has ``n_tasks`` features on the last axis."""
+        x = self.l1(x)
+        # Features live on the last axis for any number of leading dimensions.
+        mu, logsigma = torch.split(x, self.n_tasks, dim=-1)
+
+        sigma = F.softplus(logsigma) + 1e-6
+        # torch.stack takes a sequence of tensors; its second positional
+        # argument is ``dim``, so the tensors must be wrapped in a list.
+        return torch.stack([mu, sigma]).to(x.device)
+
+
+class DenseNormalGamma(Module):
+    """Linear head producing (gamma, nu, alpha, beta) per task."""
+
+    def __init__(self, n_input, n_out_tasks=1):
+        super(DenseNormalGamma, self).__init__()
+        self.n_in = n_input
+        self.n_out = 4 * n_out_tasks
+        self.n_tasks = n_out_tasks
+        self.l1 = nn.Linear(self.n_in, self.n_out)
+
+    def forward(self, x):
+        x = self.l1(x)
+        # Split on the last (feature) axis, whatever the number of leading dims.
+        gamma, lognu, logalpha, logbeta = torch.split(x, self.n_tasks, dim=-1)
+
+        nu = F.softplus(lognu)
+        alpha = F.softplus(logalpha) + 1.
+        beta = F.softplus(logbeta)
+
+        return torch.stack([gamma, nu, alpha, beta]).to(x.device)
diff --git a/evidential_deep_learning/pytorch/losses/__init__.py b/evidential_deep_learning/pytorch/losses/__init__.py
new file mode 100644
index 0000000..e0ac92d
--- /dev/null
+++ b/evidential_deep_learning/pytorch/losses/__init__.py
@@ -0,0 +1,2 @@
+from .continous import *
+from .discrete import *
\ No newline at end of file
diff --git a/evidential_deep_learning/pytorch/losses/continous.py b/evidential_deep_learning/pytorch/losses/continous.py
new file mode 100644
index 0000000..262e0f8
--- /dev/null
+++ b/evidential_deep_learning/pytorch/losses/continous.py
@@ -0,0 +1,63 @@
+import torch
+from torch.distributions import Normal
+from torch import nn
+import numpy as np
+
+MSE = nn.MSELoss(reduction='mean')
+
+
+def reduce(val, reduction):
+    """Collapse ``val`` according to ``reduction`` ('mean', 'sum' or 'none').
+
+    Raises ``ValueError`` for any other ``reduction`` value.
+    """
+    if reduction == 'none':
+        return val
+    if reduction == 'mean':
+        return val.mean()
+    if reduction == 'sum':
+        return val.sum()
+    raise ValueError(f"Invalid reduction argument: {reduction}")
+
+
+def RMSE(y, y_):
+    """Square root of the module-level mean-squared-error criterion ``MSE``."""
+    return torch.sqrt(MSE(y, y_))
+
+
+def Gaussian_NLL(y, mu, sigma, reduction='mean'):
+    """Negative log-probability of ``y`` under ``Normal(mu, sigma)``, then reduced."""
+    # TODO: refactor to mirror TF implementation due to numerical instability
+    neg_log_prob = -1. * Normal(loc=mu, scale=sigma).log_prob(y)
+    return reduce(neg_log_prob, reduction=reduction)
+
+
+def NIG_NLL(y: torch.Tensor,
+            gamma: torch.Tensor,
+            nu: torch.Tensor,
+            alpha: torch.Tensor,
+            beta: torch.Tensor, reduction='mean'):
+    """Evaluate the NIG negative log-likelihood of ``y`` element-wise, then reduce it."""
+    omega = 2 * beta * (1 + nu)
+    sq_err = (y - gamma) ** 2
+
+    nll = (0.5 * (np.pi / nu).log()
+           - alpha * omega.log()
+           + (alpha + 0.5) * (nu * sq_err + omega).log()
+           + torch.lgamma(alpha)
+           - torch.lgamma(alpha + 0.5))
+    return reduce(nll, reduction=reduction)
+
+
+def NIG_Reg(y, gamma, nu, alpha, reduction='mean'):
+    """Absolute error ``|y - gamma|`` weighted by ``2 * nu + alpha``, then reduced."""
+    error = (y - gamma).abs()
+    evidence = 2.
* nu + alpha
+    return reduce(error * evidence, reduction=reduction)
+
+
+def EvidentialRegression(y: torch.Tensor, evidential_output: torch.Tensor, lmbda=1.):
+    gamma, nu, alpha, beta = evidential_output
+    loss_nll = NIG_NLL(y, gamma, nu, alpha, beta)
+    loss_reg = NIG_Reg(y, gamma, nu, alpha)
+    return loss_nll, lmbda * loss_reg
diff --git a/evidential_deep_learning/pytorch/losses/discrete.py b/evidential_deep_learning/pytorch/losses/discrete.py
new file mode 100644
index 0000000..1044457
--- /dev/null
+++ b/evidential_deep_learning/pytorch/losses/discrete.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn.functional as F
+
+BCELoss = torch.nn.BCEWithLogitsLoss()
+
+
+def Dirichlet_SOS(y, outputs, device=None):
+    """Evaluate ``edl_log_loss`` with the arguments given as (target, output)."""
+    return edl_log_loss(outputs, y, device=device if device else outputs.device)
+
+
+def Dirichlet_Evidence(outputs):
+    """Calculate ReLU evidence"""
+    return relu_evidence(outputs)
+
+
+def Dirichlet_Matches(predictions, labels):
+    """Calculate the number of matches from index predictions"""
+    assert predictions.shape == labels.shape, f"Dimension mismatch between predictions " \
+                                              f"({predictions.shape}) and labels ({labels.shape})"
+    return torch.reshape(torch.eq(predictions, labels).float(), (-1, 1))
+
+
+def Dirichlet_Predictions(outputs):
+    """Calculate predictions from logits"""
+    return torch.argmax(outputs, dim=1)
+
+
+def Dirichlet_Uncertainty(outputs):
+    """Calculate uncertainty from logits"""
+    alpha = relu_evidence(outputs) + 1
+    return alpha.size(1) / torch.sum(alpha, dim=1, keepdim=True)
+
+
+def Sigmoid_CE(y, y_logits, device=None):
+    """Sigmoid cross-entropy between logits ``y_logits`` and targets ``y``.
+
+    Both tensors are moved to ``device`` (default: the device of
+    ``y_logits``) first.  The module-level ``BCELoss`` criterion is called
+    with tensors only, because its call signature has no ``device`` keyword.
+    """
+    device = device if device is not None else y_logits.device
+    return BCELoss(y_logits.to(device), y.to(device))
+
+
+# MIT License
+#
+# Copyright (c) 2019 Douglas Brion
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +def relu_evidence(y): + return F.relu(y) + + +def exp_evidence(y): + return torch.exp(torch.clamp(y, -10, 10)) + + +def softplus_evidence(y): + return F.softplus(y) + + +def kl_divergence(alpha, num_classes, device=None): + beta = torch.ones([1, num_classes], dtype=torch.float32, device=device) + S_alpha = torch.sum(alpha, dim=1, keepdim=True) + S_beta = torch.sum(beta, dim=1, keepdim=True) + lnB = torch.lgamma(S_alpha) - \ + torch.sum(torch.lgamma(alpha), dim=1, keepdim=True) + lnB_uni = torch.sum(torch.lgamma(beta), dim=1, + keepdim=True) - torch.lgamma(S_beta) + + dg0 = torch.digamma(S_alpha) + dg1 = torch.digamma(alpha) + + kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, + keepdim=True) + lnB + lnB_uni + return kl + + +def edl_loss(func, y, alpha, device=None): + S = torch.sum(alpha, dim=1, keepdim=True) + A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True) + + kl_alpha = (alpha - 1) * (1 - y) + 1 + kl_div = kl_divergence(kl_alpha, y.shape[1], device=device) + return A + kl_div + + +def edl_log_loss(output, target, device=None): + evidence = relu_evidence(output) + alpha = evidence + 1 + loss = torch.mean(edl_loss(torch.log, target, alpha, device=device)) + assert 
loss is not None + return loss diff --git a/evidential_deep_learning/tf/__init__.py b/evidential_deep_learning/tf/__init__.py new file mode 100644 index 0000000..1a2f7a6 --- /dev/null +++ b/evidential_deep_learning/tf/__init__.py @@ -0,0 +1 @@ +from . import losses, layers diff --git a/evidential_deep_learning/layers/__init__.py b/evidential_deep_learning/tf/layers/__init__.py similarity index 100% rename from evidential_deep_learning/layers/__init__.py rename to evidential_deep_learning/tf/layers/__init__.py diff --git a/evidential_deep_learning/layers/conv2d.py b/evidential_deep_learning/tf/layers/conv2d.py similarity index 100% rename from evidential_deep_learning/layers/conv2d.py rename to evidential_deep_learning/tf/layers/conv2d.py diff --git a/evidential_deep_learning/layers/dense.py b/evidential_deep_learning/tf/layers/dense.py similarity index 100% rename from evidential_deep_learning/layers/dense.py rename to evidential_deep_learning/tf/layers/dense.py diff --git a/evidential_deep_learning/losses/__init__.py b/evidential_deep_learning/tf/losses/__init__.py similarity index 100% rename from evidential_deep_learning/losses/__init__.py rename to evidential_deep_learning/tf/losses/__init__.py diff --git a/evidential_deep_learning/losses/continuous.py b/evidential_deep_learning/tf/losses/continuous.py similarity index 100% rename from evidential_deep_learning/losses/continuous.py rename to evidential_deep_learning/tf/losses/continuous.py diff --git a/evidential_deep_learning/losses/discrete.py b/evidential_deep_learning/tf/losses/discrete.py similarity index 100% rename from evidential_deep_learning/losses/discrete.py rename to evidential_deep_learning/tf/losses/discrete.py diff --git a/hello_world.py b/hello_world.py deleted file mode 100644 index 9181bf6..0000000 --- a/hello_world.py +++ /dev/null @@ -1,77 +0,0 @@ -import functools -import numpy as np -import matplotlib.pyplot as plt - -import evidential_deep_learning as edl -import tensorflow as tf - - -def 
main(): - # Create some training and testing data - x_train, y_train = my_data(-4, 4, 1000) - x_test, y_test = my_data(-7, 7, 1000, train=False) - - # Define our model with an evidential output - model = tf.keras.Sequential([ - tf.keras.layers.Dense(64, activation="relu"), - tf.keras.layers.Dense(64, activation="relu"), - edl.layers.DenseNormalGamma(1), - ]) - - # Custom loss function to handle the custom regularizer coefficient - def EvidentialRegressionLoss(true, pred): - return edl.losses.EvidentialRegression(true, pred, coeff=1e-2) - - # Compile and fit the model! - model.compile( - optimizer=tf.keras.optimizers.Adam(5e-4), - loss=EvidentialRegressionLoss) - model.fit(x_train, y_train, batch_size=100, epochs=500) - - # Predict and plot using the trained model - y_pred = model(x_test) - plot_predictions(x_train, y_train, x_test, y_test, y_pred) - - # Done!! - - -#### Helper functions #### -def my_data(x_min, x_max, n, train=True): - x = np.linspace(x_min, x_max, n) - x = np.expand_dims(x, -1).astype(np.float32) - - sigma = 3 * np.ones_like(x) if train else np.zeros_like(x) - y = x**3 + np.random.normal(0, sigma).astype(np.float32) - - return x, y - -def plot_predictions(x_train, y_train, x_test, y_test, y_pred, n_stds=4, kk=0): - x_test = x_test[:, 0] - mu, v, alpha, beta = tf.split(y_pred, 4, axis=-1) - mu = mu[:, 0] - var = np.sqrt(beta / (v * (alpha - 1))) - var = np.minimum(var, 1e3)[:, 0] # for visualization - - plt.figure(figsize=(5, 3), dpi=200) - plt.scatter(x_train, y_train, s=1., c='#463c3c', zorder=0, label="Train") - plt.plot(x_test, y_test, 'r--', zorder=2, label="True") - plt.plot(x_test, mu, color='#007cab', zorder=3, label="Pred") - plt.plot([-4, -4], [-150, 150], 'k--', alpha=0.4, zorder=0) - plt.plot([+4, +4], [-150, 150], 'k--', alpha=0.4, zorder=0) - for k in np.linspace(0, n_stds, 4): - plt.fill_between( - x_test, (mu - k * var), (mu + k * var), - alpha=0.3, - edgecolor=None, - facecolor='#00aeef', - linewidth=0, - zorder=1, - label="Unc." 
if k == 0 else None)
-    plt.gca().set_ylim(-150, 150)
-    plt.gca().set_xlim(-7, 7)
-    plt.legend(loc="upper left")
-    plt.show()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/pytorch_environment.yml b/pytorch_environment.yml
new file mode 100644
index 0000000..54bf83e
--- /dev/null
+++ b/pytorch_environment.yml
@@ -0,0 +1,13 @@
+name: edl
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - python=3.8
+  - mkl
+  - pytorch-lightning
+  - numpy
+  - matplotlib
+  - torchvision
+  - scipy
+
diff --git a/pytorch_validation/data/one.jpg b/pytorch_validation/data/one.jpg
new file mode 100644
index 0000000..2578d26
Binary files /dev/null and b/pytorch_validation/data/one.jpg differ
diff --git a/pytorch_validation/datamodules/__init__.py b/pytorch_validation/datamodules/__init__.py
new file mode 100644
index 0000000..c6860b1
--- /dev/null
+++ b/pytorch_validation/datamodules/__init__.py
@@ -0,0 +1 @@
+from .mnist import MNISTDataModule
diff --git a/pytorch_validation/datamodules/mnist.py b/pytorch_validation/datamodules/mnist.py
new file mode 100644
index 0000000..3dfedc1
--- /dev/null
+++ b/pytorch_validation/datamodules/mnist.py
@@ -0,0 +1,45 @@
+import pytorch_lightning as pl
+from torchvision.datasets.mnist import MNIST
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader, random_split
+
+
+class MNISTDataModule(pl.LightningDataModule):
+    """Lightning data module serving MNIST train/val/test loaders."""
+
+    def __init__(self, batch_size=32, num_workers=1):
+        super(MNISTDataModule, self).__init__()
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+    def setup(self, stage=None):
+        """Create the datasets: a 95%/5% train/val split of the MNIST training set plus the test set."""
+        full_mnist = MNIST("./data/mnist",
+                           download=True,
+                           train=True,
+                           transform=transforms.Compose([
+                               transforms.Resize((28, 28)),
+                               transforms.ToTensor()]))
+        # The first (95%) share becomes data_train, the remainder data_val.
+        train_size = int(0.95 * len(full_mnist))
+        val_size = len(full_mnist) - train_size
+        self.data_train, self.data_val = random_split(full_mnist, [train_size, val_size])
+
+        self.data_test = MNIST("./data/mnist",
+                               train=False,
+                               download=True,
transform=transforms.Compose([ + transforms.Resize((28, 28)), + transforms.ToTensor()])) + + def train_dataloader(self): + return DataLoader(self.data_train, shuffle=True, + num_workers=self.num_workers, batch_size=self.batch_size, pin_memory=False) + + def test_dataloader(self): + return DataLoader(self.data_test, shuffle=False, + num_workers=self.num_workers, batch_size=self.batch_size, pin_memory=False) + + def val_dataloader(self): + return DataLoader(self.data_val, shuffle=False, + num_workers=self.num_workers, batch_size=self.batch_size, pin_memory=False) \ No newline at end of file diff --git a/pytorch_validation/models/__init__.py b/pytorch_validation/models/__init__.py new file mode 100644 index 0000000..a3c6688 --- /dev/null +++ b/pytorch_validation/models/__init__.py @@ -0,0 +1 @@ +from .lenet import LeNet diff --git a/pytorch_validation/models/lenet.py b/pytorch_validation/models/lenet.py new file mode 100644 index 0000000..0dde774 --- /dev/null +++ b/pytorch_validation/models/lenet.py @@ -0,0 +1,92 @@ +# MIT License +# +# Copyright (c) 2019 Douglas Brion +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch +import torch.nn as nn +import torchvision +import pytorch_lightning as pl +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts + +from evidential_deep_learning.pytorch.losses import Dirichlet_SOS, Dirichlet_Predictions, Dirichlet_Uncertainty + + +class LeNet(pl.LightningModule): + def __init__(self, dropout=False): + super(LeNet, self).__init__() + self.use_dropout = dropout + self.model = nn.Sequential( + nn.Conv2d(1, 20, kernel_size=5), + nn.MaxPool2d(1), + nn.ReLU(), + nn.Conv2d(20, 50, kernel_size=5), + nn.MaxPool2d(1), + nn.ReLU(), + nn.Flatten(), + nn.Linear(20000, 500), + nn.ReLU(), + nn.Dropout(0.2) if dropout else nn.Identity(), + nn.Linear(500, 10) + ) + self.accuracy = pl.metrics.Accuracy() + + def forward(self, x): + return self.model(x) + + def shared_inference_step(self, batch): + inputs, labels = batch + y = torch.eye(10, device=self.device)[labels] + outputs = self(inputs) + loss = Dirichlet_SOS(y, outputs, device=self.device) + return loss, outputs + + def training_step(self, batch, batch_idx): + loss, _ = self.shared_inference_step(batch) + self.log('train_loss', loss) + return loss + + def validation_step(self, batch, batch_idx): + loss, outputs = self.shared_inference_step(batch) + + # Logging + self.log('val_loss', loss) + predictions = Dirichlet_Predictions(outputs) + u = Dirichlet_Uncertainty(outputs) + self.log('accuracy', self.accuracy(predictions, batch[1]), on_step=True, on_epoch=True, logger=True) + self.log('mean_uncertainty', u.mean()) + tb = self.logger.experiment + images = batch[0][:9] + grid = torchvision.utils.make_grid(images, nrow=3) + tb.add_image('images', grid, self.global_step) + tb.add_graph(self, images) + + if batch_idx == 0: 
+            for name, param in self.named_parameters():
+                tb.add_histogram(name, param.clone().cpu().data.numpy(), self.current_epoch)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
+        scheduler = {
+            'scheduler': CosineAnnealingWarmRestarts(optimizer, T_0=3),
+            'interval': 'step'
+        }
+        return [optimizer], [scheduler]
diff --git a/pytorch_validation/validate.py b/pytorch_validation/validate.py
new file mode 100644
index 0000000..60ccae0
--- /dev/null
+++ b/pytorch_validation/validate.py
@@ -0,0 +1,128 @@
+from datamodules import MNISTDataModule
+from models import LeNet
+from torchvision import transforms
+from evidential_deep_learning.pytorch.losses import *
+import pytorch_lightning as pl
+from pytorch_lightning import loggers as pl_loggers
+from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
+import matplotlib.pyplot as plt
+import scipy.ndimage as nd
+from torchvision.datasets.mnist import MNIST
+# np and torch are used below; import them explicitly instead of relying on
+# whatever names the wildcard import above happens to re-export.
+import numpy as np
+import torch
+
+
+if __name__ == '__main__':
+    model: pl.LightningModule = LeNet(dropout=True)
+    dm = MNISTDataModule(batch_size=32, num_workers=4)
+
+    logger = pl_loggers.TensorBoardLogger('logs')
+    lr_monitor = LearningRateMonitor(logging_interval='step')
+    trainer = pl.Trainer(max_epochs=50, logger=logger, accumulate_grad_batches=16, callbacks=[lr_monitor])
+    trainer.fit(model, datamodule=dm)
+
+    model.eval()
+    model.freeze()
+
+    # MIT License
+    #
+    # Copyright (c) 2019 Douglas Brion
+    #
+    # Permission is hereby granted, free of charge, to any person obtaining a copy
+    # of this software and associated documentation files (the "Software"), to deal
+    # in the Software without restriction, including without limitation the rights
+    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+    # copies of the Software, and to permit persons to whom the Software is
+    # furnished to do so, subject to the following conditions:
+    #
+    # The above copyright notice and this permission notice shall be
included in all + # copies or substantial portions of the Software. + # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + # SOFTWARE. + + num_classes = 10 + Mdeg = 180 + Ndeg = int(Mdeg / 10) + 1 + ldeg = [] + lp = [] + lu = [] + classifications = [] + + scores = np.zeros((1, num_classes)) + rimgs = np.zeros((28, 28 * Ndeg)) + + data_val = MNIST("./data/mnist", + train=False, + download=True, + transform=transforms.Compose([ + transforms.Resize((28, 28)), + transforms.ToTensor()])) + + def rotate_img(x, deg): + return nd.rotate(x.reshape(28, 28), deg, reshape=False).ravel() + + img, _ = data_val[5] + with torch.no_grad(): + for i, deg in enumerate(np.linspace(0, Mdeg, Ndeg)): + nimg = rotate_img(img.numpy()[0], deg).reshape(28, 28) + + nimg = np.clip(a=nimg, a_min=0, a_max=1) + + rimgs[:, i*28:(i+1)*28] = nimg + trans = transforms.ToTensor() + img_tensor = trans(nimg) + img_tensor.unsqueeze_(0) + + outputs = model(img_tensor) + uncertainty = Dirichlet_Uncertainty(outputs) + alpha = Dirichlet_Evidence(outputs) + 1 + preds = Dirichlet_Predictions(outputs) + prob = alpha / torch.sum(alpha, dim=1, keepdim=True) + output = outputs.flatten() + prob = prob.flatten() + preds = preds.flatten() + classifications.append(preds[0].item()) + lu.append(uncertainty.mean()) + + scores += prob.detach().cpu().numpy() >= 0.5 + ldeg.append(deg) + lp.append(prob.tolist()) + + labels = np.arange(10)[scores[0].astype(bool)] + lp = np.array(lp)[:, labels] + c = ["black", "blue", "red", "brown", "purple", "cyan"] + marker = ["s", "^", "o"]*2 + labels = labels.tolist() + 
fig, axs = plt.subplots(3, gridspec_kw={"height_ratios": [4, 1, 12]}) + + for i in range(len(labels)): + axs[2].plot(ldeg, lp[:, i], marker=marker[i], c=c[i]) + + labels += ["uncertainty"] + axs[2].plot(ldeg, lu, marker="<", c="red") + + print(classifications) + + axs[0].set_title("Rotated \"1\" Digit Classifications") + axs[0].imshow(1 - rimgs, cmap="gray") + axs[0].axis("off") + + empty_lst = [classifications] + axs[1].table(cellText=empty_lst, bbox=[0, 1.2, 1, 1]) + axs[1].axis("off") + + axs[2].legend(labels, loc='best') + axs[2].set_xlim([0, Mdeg]) + axs[2].set_ylim([0, 1]) + axs[2].set_xlabel("Rotation Degree") + axs[2].set_ylabel("Classification Probability") + + plt.savefig('pytorch_discrete_validation.png') diff --git a/setup.py b/setup.py index 82fca3f..5b07ddc 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ install_requires=[ "numpy", "matplotlib", - ], # Tensorflow must be installed manually + ], # Tensorflow or Pytorch must be installed manually python_requires='>=3.7', classifiers=[ "Programming Language :: Python",