From d8b4d0e83ed39bb111ec313dd3dd86151a8d09c0 Mon Sep 17 00:00:00 2001
From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com>
Date: Sat, 25 Feb 2023 21:55:52 +0100
Subject: [PATCH 1/7] Check that input resolution is 128 x 128 or higher

---
 torchvision/models/optical_flow/raft.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index 37da4ff0a44..84583d9c810 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -475,6 +475,8 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
             raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
         if not (h % 8 == 0) and (w % 8 == 0):
             raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
+        if not (h < 128) and (w < 128):
+            raise ValueError(f"input image H and W should be equal or larger than 128, instead got {h} (h) and {w} (w)")
 
         fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
         fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)

From 8e2358fbe4d260526a35af1e5a93de6eed53c254 Mon Sep 17 00:00:00 2001
From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com>
Date: Mon, 27 Feb 2023 06:26:58 -0500
Subject: [PATCH 2/7] Fix error message

Co-authored-by: Nicolas Hug
---
 torchvision/models/optical_flow/raft.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index 84583d9c810..da0b8a0bea7 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -475,8 +475,8 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
             raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
         if not (h % 8 == 0) and (w % 8 == 0):
             raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
-        if not (h < 128) and (w < 128):
-            raise ValueError(f"input image H and W should be equal or larger than 128, instead got {h} (h) and {w} (w)")
+        if h < 128 or w < 128:
+            raise ValueError(f"input image H and W should be equal to or larger than 128, instead got {h} (h) and {w} (w)")
 
         fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
         fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)

From 9d42a8f3534bdb4729c15249e741631132c625b3 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 27 Feb 2023 12:39:53 +0000
Subject: [PATCH 3/7] formatting

---
 torchvision/models/optical_flow/raft.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index da0b8a0bea7..fbee8860527 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -476,7 +476,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
         if not (h % 8 == 0) and (w % 8 == 0):
             raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
         if h < 128 or w < 128:
-            raise ValueError(f"input image H and W should be equal to or larger than 128, instead got {h} (h) and {w} (w)")
+            raise ValueError(
+                f"input image H and W should be equal to or larger than 128, instead got {h} (h) and {w} (w)"
+            )
 
         fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
         fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
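Up to this point the series guards the input in RAFT.forward by comparing the raw resolution against a hard-coded 128. Note that patch 2, despite its subject, also fixes the condition itself: `not (h < 128) and (w < 128)` parses as `(not (h < 128)) and (w < 128)`, so the original guard only fired when h was at least 128 while w was below it, silently accepting e.g. a 96 x 96 input. A minimal sketch of how the fixed check surfaces to callers (assuming a torchvision build with patches 1-3 applied; the 96 x 96 pair is a dummy input chosen to pass the divisible-by-8 check but fail the new one):

    import torch
    from torchvision.models.optical_flow import raft_small

    # Random weights are enough to exercise the shape validation at the top of forward().
    model = raft_small(weights=None).eval()
    img1 = torch.rand(1, 3, 96, 96)  # divisible by 8, but below the 128-pixel minimum
    img2 = torch.rand(1, 3, 96, 96)
    try:
        model(img1, img2)
    except ValueError as err:
        print(err)  # input image H and W should be equal to or larger than 128, ...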
From 930fcf706e8e14da94bf8323432c115bcfb9a688 Mon Sep 17 00:00:00 2001
From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com>
Date: Mon, 27 Feb 2023 16:08:40 +0100
Subject: [PATCH 4/7] Check resolution at the feature level w/ corr_block
 parameters

---
 torchvision/models/optical_flow/raft.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index fbee8860527..e4ebe3c7c64 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -475,16 +475,21 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
             raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
         if not (h % 8 == 0) and (w % 8 == 0):
             raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
-        if h < 128 or w < 128:
-            raise ValueError(
-                f"input image H and W should be equal to or larger than 128, instead got {h} (h) and {w} (w)"
-            )
 
         fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
         fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
         if fmap1.shape[-2:] != (h // 8, w // 8):
             raise ValueError("The feature encoder should downsample H and W by 8")
-
+
+        _, _, h_fmap, w_fmap = fmap1.shape
+        if not (((h_fmap // 2**(self.corr_block.num_levels - 1))) < 2) and (
+            ((w_fmap // 2**(self.corr_block.num_levels - 1))) < 2
+        ):
+            min_res = 2 * 2**(self.corr_block.num_levels - 1) * 8
+            raise ValueError(
+                f"input image resolution is too small image resolution should be at least {min_res} (h) and {min_res} (w), got {h} (h) and {w} (w)"
+            )
+
         self.corr_block.build_pyramid(fmap1, fmap2)
 
         context_out = self.context_encoder(image1)

From ca2134d5f16ad758faee1aceba9f8b9d5632bd4e Mon Sep 17 00:00:00 2001
From: ChristophReich1996
Date: Wed, 10 May 2023 16:48:43 +0200
Subject: [PATCH 5/7] Move min res check to build_pyramid and simplify

---
 torchvision/models/optical_flow/raft.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index b227f9a72b5..84587bf0658 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -369,6 +369,14 @@ def build_pyramid(self, fmap1, fmap2):
             raise ValueError(
                 f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)"
             )
+
+        _, _, h_fmap, w_fmap = fmap1.shape
+        min_res = 2 * 2 ** (self.num_levels - 1) * 8
+        if (min_res // 8, min_res // 8) <= (h_fmap, w_fmap):
+            raise ValueError(
+                f"Input image resolution is too small, resolution should be at least {min_res} (h) and {min_res} (w)"
+            )
+
         corr_volume = self._compute_corr_volume(fmap1, fmap2)
 
         batch_size, h, w, num_channels, _, _ = corr_volume.shape  # _, _ = h, w
@@ -481,15 +489,6 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
         if fmap1.shape[-2:] != (h // 8, w // 8):
             raise ValueError("The feature encoder should downsample H and W by 8")
 
-        _, _, h_fmap, w_fmap = fmap1.shape
-        if not (((h_fmap // 2**(self.corr_block.num_levels - 1))) < 2) and (
-            ((w_fmap // 2**(self.corr_block.num_levels - 1))) < 2
-        ):
-            min_res = 2 * 2**(self.corr_block.num_levels - 1) * 8
-            raise ValueError(
-                f"input image resolution is too small image resolution should be at least {min_res} (h) and {min_res} (w), got {h} (h) and {w} (w)"
-            )
-
         self.corr_block.build_pyramid(fmap1, fmap2)
 
         context_out = self.context_encoder(image1)
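Where patch 4's bound comes from: the feature encoder makes the maps 8x smaller than the input, the correlation pyramid then halves them (num_levels - 1) more times, and the coarsest level must keep at least 2 values per dimension. Also note that patch 5's simplified comparison, `(min_res // 8, min_res // 8) <= (h_fmap, w_fmap)`, orders tuples lexicographically rather than element-wise and is inverted (it raises precisely for feature maps that are large enough), which is what the next patch corrects. A worked sketch of the arithmetic, assuming num_levels=4, the default CorrBlock configuration used by both raft_large and raft_small:

    # Standalone recomputation of the bound; num_levels mirrors CorrBlock.num_levels.
    num_levels = 4
    min_fmap_size = 2 * 2 ** (num_levels - 1)  # coarsest pyramid level keeps >= 2 values
    min_input_size = 8 * min_fmap_size         # the feature encoder downsamples H and W by 8
    print(min_fmap_size, min_input_size)       # 16 128 -> recovers patch 1's constant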
From 3f353ec3537c4f338137d9c0baa78916b248dba9 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 11 May 2023 11:02:32 +0000
Subject: [PATCH 6/7] Update logic

---
 torchvision/models/optical_flow/raft.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index 84587bf0658..f01ba465abf 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -370,11 +370,16 @@ def build_pyramid(self, fmap1, fmap2):
                 f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)"
             )
 
-        _, _, h_fmap, w_fmap = fmap1.shape
-        min_res = 2 * 2 ** (self.num_levels - 1) * 8
-        if (min_res // 8, min_res // 8) <= (h_fmap, w_fmap):
+        # Explaining min_fmap_size below: the fmaps are down-sampled (num_levels - 1) times by a factor of 2.
+        # The resulting corr_volume must have at least 2 values (hence the 2* factor), otherwise grid_sample() would
+        # produce nans in its output.
+        min_fmap_size = 2 * (2 ** (self.num_levels - 1))
+        if any(fmap_size < min_fmap_size for fmap_size in fmap1.shape[-2:]):
             raise ValueError(
-                f"Input image resolution is too small, resolution should be at least {min_res} (h) and {min_res} (w)"
+                f"Feature maps are too small to be down-sampled by the correlation pyramid. "
+                f"H and W of feature maps should be at least {min_fmap_size}; got: {fmap1.shape[-2:]}. "
+                "Remember that input images to the model are downsampled by 8, so that means their "
+                f"dimensions should be at least 8 * {min_fmap_size} = {8 * min_fmap_size}."
             )
 
         corr_volume = self._compute_corr_volume(fmap1, fmap2)
@@ -488,7 +493,7 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
         fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
         if fmap1.shape[-2:] != (h // 8, w // 8):
             raise ValueError("The feature encoder should downsample H and W by 8")
-
+
         self.corr_block.build_pyramid(fmap1, fmap2)
 
         context_out = self.context_encoder(image1)

From 1cbcadaf01fadf7e646b2e7067f045bd36af9b4a Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 11 May 2023 11:07:33 +0000
Subject: [PATCH 7/7] lint

---
 torchvision/models/optical_flow/raft.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index f01ba465abf..c294777ee6f 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -371,12 +371,12 @@ def build_pyramid(self, fmap1, fmap2):
             )
 
         # Explaining min_fmap_size below: the fmaps are down-sampled (num_levels - 1) times by a factor of 2.
-        # The resulting corr_volume must have at least 2 values (hence the 2* factor), otherwise grid_sample() would
+        # The last corr_volume must have at least 2 values (hence the 2* factor), otherwise grid_sample() would
         # produce nans in its output.
         min_fmap_size = 2 * (2 ** (self.num_levels - 1))
         if any(fmap_size < min_fmap_size for fmap_size in fmap1.shape[-2:]):
             raise ValueError(
-                f"Feature maps are too small to be down-sampled by the correlation pyramid. "
+                "Feature maps are too small to be down-sampled by the correlation pyramid. "
                 f"H and W of feature maps should be at least {min_fmap_size}; got: {fmap1.shape[-2:]}. "
                 "Remember that input images to the model are downsampled by 8, so that means their "
                 f"dimensions should be at least 8 * {min_fmap_size} = {8 * min_fmap_size}."
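The comment added in patch 6 carries the core reasoning: build_pyramid repeatedly halves the correlation volume, and once a level collapses below 2 values per dimension, grid_sample returns nans when sampling it. A sketch of that downsampling schedule (assuming the pyramid levels are produced by 2x average pooling, as in torchvision's CorrBlock, with the default num_levels=4):

    import torch
    import torch.nn.functional as F

    num_levels = 4
    x = torch.rand(1, 1, 16, 16)  # exactly min_fmap_size on each side
    for level in range(1, num_levels):
        x = F.avg_pool2d(x, kernel_size=2, stride=2)
        print(level, tuple(x.shape[-2:]))
    # 1 (8, 8)
    # 2 (4, 4)
    # 3 (2, 2)  a 15 x 15 map would instead bottom out at 1 x 1, which the new check forbids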