Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions torchvision/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,10 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
if not (h % 8 == 0) and (w % 8 == 0):
raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
if h < 128 or w < 128:
raise ValueError(
f"input image H and W should be equal to or larger than 128, instead got {h} (h) and {w} (w)"
)

fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
Expand Down