Skip to content

Commit 930fcf7

Browse files
Check resolution at the feature level w/ corr_block parameters
1 parent 9d42a8f commit 930fcf7

File tree

1 file changed

+10
-5
lines changed
  • torchvision/models/optical_flow

1 file changed

+10
-5
lines changed

torchvision/models/optical_flow/raft.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -475,16 +475,21 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
475475
raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
476476
if not (h % 8 == 0) and (w % 8 == 0):
477477
raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
478-
if h < 128 or w < 128:
479-
raise ValueError(
480-
f"input image H and W should be equal to or larger than 128, instead got {h} (h) and {w} (w)"
481-
)
482478

483479
fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
484480
fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
485481
if fmap1.shape[-2:] != (h // 8, w // 8):
486482
raise ValueError("The feature encoder should downsample H and W by 8")
487-
483+
484+
_, _, h_fmap, w_fmap = fmap1.shape
485+
if not (((h_fmap // 2**(self.corr_block.num_levels - 1))) < 2) and (
486+
((w_fmap // 2**(self.corr_block.num_levels - 1))) < 2
487+
):
488+
min_res = 2 * 2**(self.corr_block.num_levels - 1) * 8
489+
raise ValueError(
490+
f"input image resolution is too small image resolution should be at least {min_res} (h) and {min_res} (w), got {h} (h) and {w} (w)"
491+
)
492+
488493
self.corr_block.build_pyramid(fmap1, fmap2)
489494

490495
context_out = self.context_encoder(image1)

0 commit comments

Comments
 (0)