Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion ignite/metrics/gan/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class FID(_BaseInceptionMetric):

Remark:

This implementation is inspired by pytorch_fid package which can be found `here`__
This implementation is inspired by `pytorch_fid` package which can be found `here`__

__ https://github.com/mseitzer/pytorch-fid

Expand Down Expand Up @@ -114,6 +114,47 @@ class FID(_BaseInceptionMetric):

0.0

.. note::

The default `torchvision` model used is InceptionV3 pretrained on ImageNet.
This can lead to differences in results with `pytorch_fid`. To find comparable results,
the following model wrapper should be used:

.. code::

import torch.nn as nn

# wrapper class as feature_extractor
class WrapperInceptionV3(nn.Module):

def __init__(self, fid_incv3):
super().__init__()
self.fid_incv3 = fid_incv3

@torch.no_grad()
def forward(self, x):
y = self.fid_incv3(x)
y = y[0]
y = y[:, :, 0, 0]
return y

# use cpu rather than cuda to get comparable results
device = "cpu"

# pytorch_fid model
dims = 2048
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx]).to(device)

# wrapper model to pytorch_fid model
wrapper_model = WrapperInceptionV3(model)
wrapper_model.eval();

# comparable metric
pytorch_fid_metric = FID(num_features=dims, feature_extractor=wrapper_model)

Important, `pytorch_fid` results depend on the batch size if the device is `cuda`.

.. versionadded:: 0.4.6
"""

Expand Down