Skip to content

Commit b06ea39

Browse files
Assert RAFT input resolution is 128 x 128 or higher (#7339)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 62c2231 commit b06ea39

File tree

1 file changed

+13
-0
lines changed
  • torchvision/models/optical_flow

1 file changed

+13
-0
lines changed

torchvision/models/optical_flow/raft.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -369,6 +369,19 @@ def build_pyramid(self, fmap1, fmap2):
369369
raise ValueError(
370370
f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)"
371371
)
372+
373+
# Explaining min_fmap_size below: the fmaps are down-sampled (num_levels - 1) times by a factor of 2.
374+
# The last corr_volume must have at least 2 values (hence the 2* factor), otherwise grid_sample() would
375+
# produce nans in its output.
376+
min_fmap_size = 2 * (2 ** (self.num_levels - 1))
377+
if any(fmap_size < min_fmap_size for fmap_size in fmap1.shape[-2:]):
378+
raise ValueError(
379+
"Feature maps are too small to be down-sampled by the correlation pyramid. "
380+
f"H and W of feature maps should be at least {min_fmap_size}; got: {fmap1.shape[-2:]}. "
381+
"Remember that input images to the model are downsampled by 8, so that means their "
382+
f"dimensions should be at least 8 * {min_fmap_size} = {8 * min_fmap_size}."
383+
)
384+
372385
corr_volume = self._compute_corr_volume(fmap1, fmap2)
373386

374387
batch_size, h, w, num_channels, _, _ = corr_volume.shape # _, _ = h, w

0 commit comments

Comments
 (0)