diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index e682fda2cdc..c294777ee6f 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -369,6 +369,19 @@ def build_pyramid(self, fmap1, fmap2): raise ValueError( f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)" ) + + # Explaining min_fmap_size below: the fmaps are down-sampled (num_levels - 1) times by a factor of 2. + # The last corr_volume must have at least 2 values (hence the 2* factor), otherwise grid_sample() would + # produce nans in its output. + min_fmap_size = 2 * (2 ** (self.num_levels - 1)) + if any(fmap_size < min_fmap_size for fmap_size in fmap1.shape[-2:]): + raise ValueError( + "Feature maps are too small to be down-sampled by the correlation pyramid. " + f"H and W of feature maps should be at least {min_fmap_size}; got: {fmap1.shape[-2:]}. " + "Remember that input images to the model are downsampled by 8, so that means their " + f"dimensions should be at least 8 * {min_fmap_size} = {8 * min_fmap_size}." + ) + corr_volume = self._compute_corr_volume(fmap1, fmap2) batch_size, h, w, num_channels, _, _ = corr_volume.shape # _, _ = h, w