Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
dfb2862
replace most asserts with exceptions
jdsgomes Mar 10, 2022
40d0528
fix formating issues
jdsgomes Mar 10, 2022
13bfd80
fix linting and remove more asserts
jdsgomes Mar 11, 2022
45ecd61
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 11, 2022
f522368
fix regresion
jdsgomes Mar 11, 2022
23bd022
Merge branch 'replace_asserts_with_exceptions' of github.com:jdsgomes…
jdsgomes Mar 11, 2022
6a87e4d
fix regresion
jdsgomes Mar 11, 2022
e179358
fix bug
jdsgomes Mar 11, 2022
30b1714
apply ufmt
jdsgomes Mar 11, 2022
488d2af
apply ufmt
jdsgomes Mar 11, 2022
38d2d01
fix tests
jdsgomes Mar 11, 2022
7d42574
fix format
jdsgomes Mar 11, 2022
dc6856b
fix None check
jdsgomes Mar 11, 2022
2c56adc
fix detection models tests
jdsgomes Mar 11, 2022
aebca6d
non scriptable any
jdsgomes Mar 11, 2022
d54b582
add more checks for None values
jdsgomes Mar 13, 2022
36d2174
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 13, 2022
98c2702
fix retinanet test
jdsgomes Mar 13, 2022
bdab5f4
Merge branch 'replace_asserts_with_exceptions' of github.com:jdsgomes…
jdsgomes Mar 13, 2022
4900653
fix retinanet test
jdsgomes Mar 13, 2022
d5ccbf1
Update references/classification/transforms.py
jdsgomes Mar 14, 2022
de2f4b7
Update references/classification/transforms.py
jdsgomes Mar 14, 2022
275012a
Update references/optical_flow/transforms.py
jdsgomes Mar 14, 2022
fddd2ac
Update references/optical_flow/transforms.py
jdsgomes Mar 14, 2022
7e60b46
Update references/optical_flow/transforms.py
jdsgomes Mar 14, 2022
6c2e94f
make value checks more pythonic:
jdsgomes Mar 14, 2022
0a78c6b
fix merge
jdsgomes Mar 14, 2022
cb95c97
Update references/optical_flow/transforms.py
jdsgomes Mar 14, 2022
ff8f557
make value checks more pythonic
jdsgomes Mar 14, 2022
0598990
Merge branch 'replace_asserts_with_exceptions' of github.com:jdsgomes…
jdsgomes Mar 14, 2022
abafdb2
make more checks pythonic
jdsgomes Mar 14, 2022
5b30ce3
fix bug
jdsgomes Mar 14, 2022
ade3364
appy ufmt
jdsgomes Mar 14, 2022
2f4ecc1
fix tracing issues
jdsgomes Mar 14, 2022
981617b
fib typos
jdsgomes Mar 14, 2022
fec7d4b
fix lint
jdsgomes Mar 14, 2022
bdd913b
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 14, 2022
ca59cd7
remove unecessary f-strings
jdsgomes Mar 14, 2022
3391f00
Merge branch 'replace_asserts_with_exceptions' of github.com:jdsgomes…
jdsgomes Mar 14, 2022
81ac57c
fix bug
jdsgomes Mar 14, 2022
7affc95
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 14, 2022
8dc76e2
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 15, 2022
e68c1be
Update torchvision/datasets/mnist.py
jdsgomes Mar 15, 2022
d92a0f9
Update torchvision/datasets/mnist.py
jdsgomes Mar 15, 2022
4a30fc9
Update torchvision/ops/boxes.py
jdsgomes Mar 15, 2022
e4f214d
Update torchvision/ops/poolers.py
jdsgomes Mar 15, 2022
9e9ca6d
Update torchvision/utils.py
jdsgomes Mar 15, 2022
b234e08
address PR comments
jdsgomes Mar 15, 2022
1a45e1e
Update torchvision/io/_video_opt.py
jdsgomes Mar 15, 2022
8437088
Update torchvision/models/detection/generalized_rcnn.py
jdsgomes Mar 15, 2022
cff417a
Update torchvision/models/feature_extraction.py
jdsgomes Mar 15, 2022
1d9e3d3
Update torchvision/models/optical_flow/raft.py
jdsgomes Mar 15, 2022
ce06c29
address PR comments
jdsgomes Mar 15, 2022
2b1870f
addressing further pr comments
jdsgomes Mar 15, 2022
851adb2
fix bug
jdsgomes Mar 15, 2022
a915f1f
remove unecessary else
jdsgomes Mar 15, 2022
f41e115
apply ufmt
jdsgomes Mar 15, 2022
ee21d2e
last pr comment
jdsgomes Mar 15, 2022
d000238
replace RuntimeErrors
jdsgomes Mar 15, 2022
d77739b
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions references/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ class RandomMixup(torch.nn.Module):

def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."
if not num_classes > 0:
raise ValueError("Please provide a valid positive value for the num_classes.")

if not alpha > 0:
raise ValueError("Alpha param can't be zero.")

self.num_classes = num_classes
self.p = p
Expand Down Expand Up @@ -99,8 +102,10 @@ class RandomCutmix(torch.nn.Module):

def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."
if not num_classes > 0:
raise ValueError("Please provide a valid positive value for the num_classes.")
if not alpha > 0:
raise ValueError("Alpha param can't be zero.")

self.num_classes = num_classes
self.p = p
Expand Down
3 changes: 2 additions & 1 deletion references/detection/coco_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

class CocoEvaluator:
def __init__(self, coco_gt, iou_types):
assert isinstance(iou_types, (list, tuple))
if not isinstance(iou_types, (list, tuple)):
raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
coco_gt = copy.deepcopy(coco_gt)
self.coco_gt = coco_gt

Expand Down
5 changes: 4 additions & 1 deletion references/detection/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ def _has_valid_annotation(anno):
return True
return False

assert isinstance(dataset, torchvision.datasets.CocoDetection)
if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
Expand Down
22 changes: 14 additions & 8 deletions references/optical_flow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@ class ValidateModelInput(torch.nn.Module):
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
def forward(self, img1, img2, flow, valid_flow_mask):

assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)
if not all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None):
raise TypeError("This method expects all input arguments to be of type torch.Tensor.")
if not all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None):
raise TypeError("This method expects the tensors img1, img2 and flow of be of dtype torch.float32.")

assert img1.shape == img2.shape
if not img1.shape == img2.shape:
raise ValueError("img1 and img2 should have the same shape.")
h, w = img1.shape[-2:]
if flow is not None:
assert flow.shape == (2, h, w)
if flow is not None and not flow.shape == (2, h, w):
raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
if valid_flow_mask is not None:
assert valid_flow_mask.shape == (h, w)
assert valid_flow_mask.dtype == torch.bool
if not valid_flow_mask.shape == (h, w):
raise ValueError(f"valid_flow_mask.shape should be ({h}, {w}) instead of {valid_flow_mask.shape}")
if not valid_flow_mask.dtype == torch.bool:
raise TypeError("valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}")

return img1, img2, flow, valid_flow_mask

Expand Down Expand Up @@ -109,7 +114,8 @@ class RandomErasing(T.RandomErasing):
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
self.max_erase = max_erase
assert self.max_erase > 0
if not self.max_erase > 0:
raise ValueError("max_raise should be greater than 0")

def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
Expand Down
5 changes: 4 additions & 1 deletion references/optical_flow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)

def __getattr__(self, attr):
Expand Down
6 changes: 5 additions & 1 deletion references/segmentation/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def _has_valid_annotation(anno):
# if more than 1k pixels occupied in the image
return sum(obj["area"] for obj in anno) > 1000

assert isinstance(dataset, torchvision.datasets.CocoDetection)
if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)

ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
Expand Down
5 changes: 4 additions & 1 deletion references/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)

def __getattr__(self, attr):
Expand Down
3 changes: 2 additions & 1 deletion references/similarity/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(self, groups, p, k):
self.groups = create_groups(groups, self.k)

# Ensures there are enough classes to sample from
assert len(self.groups) >= p
if not len(self.groups) >= p:
raise ValueError("There are not enought classes to sample from")

def __iter__(self):
# Shuffle samples within groups
Expand Down
5 changes: 4 additions & 1 deletion references/video_classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)

def __getattr__(self, attr):
Expand Down
6 changes: 3 additions & 3 deletions test/test_backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,16 @@ def test_build_fx_feature_extractor(self, model_name):
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
# Check must specify return nodes
with pytest.raises(AssertionError):
with pytest.raises(RuntimeError):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether a ValueError makes more sense here, and almost everywhere else a RuntimeError is raised? Usually when there's an issue with user-provided input, it's either a TypeError or a ValueError.

Also, it helps a lot to use the match=expected_err_msg parameter of pytest.raises: we can make sure that the exception that gets raise is indeed the one we expect, and it's helpful to document the tests as well. We don't have to match the entire error message, sometimes just matching the relevant part is good enough. We could leave that part for a follow up though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was in doubt in this one. I was using a strict definition of ValueError: "Raised when an operation or function receives an argument that has the right type but an inappropriate value". And in this case each argument might have an appropriate value, but the combination of them is not appropriate. But I am ok with your interpretation as these are user inputs so still fall under ValueError.

As for the expected error messages I would leave it for another PR as it is outside the scope of this one.

self._create_feature_extractor(model)
# Check return_nodes and train_return_nodes / eval_return nodes
# mutual exclusivity
with pytest.raises(AssertionError):
with pytest.raises(RuntimeError):
self._create_feature_extractor(
model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
)
# Check train_return_nodes / eval_return nodes must both be specified
with pytest.raises(AssertionError):
with pytest.raises(RuntimeError):
self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
# Check invalid node name raises ValueError
with pytest.raises(ValueError):
Expand Down
4 changes: 2 additions & 2 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ def test_autocast(self, x_dtype, rois_dtype):

def _helper_boxes_shape(self, func):
# test boxes as Tensor[N, 5]
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
func(a, boxes, output_size=(2, 2))

# test boxes as List[Tensor[N, 4]]
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
ops.roi_pool(a, [boxes], output_size=(2, 2))
Expand Down
3 changes: 2 additions & 1 deletion torchvision/datasets/kinetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def __init__(
print("Using legacy structure")
self.split_folder = root
self.split = "unknown"
assert not download, "Cannot download the videos using legacy_structure."
if download:
raise RuntimeError("Cannot download the videos using legacy_structure.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be a ValueError rather than a RuntimeError

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree - thanks for spotting these inconsistencies

else:
self.split_folder = path.join(root, split)
self.split = verify_str_arg(split, arg="split", valid_values=["train", "val"])
Expand Down
21 changes: 14 additions & 7 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,14 @@ def _check_exists(self) -> bool:

def _load_data(self):
data = read_sn3_pascalvincent_tensor(self.images_file)
assert data.dtype == torch.uint8
assert data.ndimension() == 3
if not data.dtype == torch.uint8:
raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
if not data.ndimension() == 3:
raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")

targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
assert targets.ndimension() == 2
if not targets.ndimension() == 2:
raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")

if self.what == "test10k":
data = data[0:10000, :, :].clone()
Expand Down Expand Up @@ -530,13 +533,17 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso

def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert x.dtype == torch.uint8
assert x.ndimension() == 1
if not x.dtype == torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if not x.ndimension() == 1:
raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
return x.long()


def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert x.dtype == torch.uint8
assert x.ndimension() == 3
if not x.dtype == torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if not x.ndimension() == 3:
raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
return x
10 changes: 4 additions & 6 deletions torchvision/datasets/samplers/clip_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ def __init__(
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
assert (
len(dataset) % group_size == 0
), "dataset length must be a multiplier of group size dataset length: %d, group size: %d" % (
len(dataset),
group_size,
)
if not len(dataset) % group_size == 0:
raise ValueError(
f"dataset length must be a multiplier of group size dataset length: {len(dataset)}, group size: {group_size}"
)
self.dataset = dataset
self.group_size = group_size
self.num_replicas = num_replicas
Expand Down
1 change: 0 additions & 1 deletion torchvision/datasets/sbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def __init__(

self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
assert len(self.images) == len(self.masks)

self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target

Expand Down
3 changes: 2 additions & 1 deletion torchvision/datasets/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> tor
`step` between windows. The distance between each element
in a window is given by `dilation`.
"""
assert tensor.dim() == 1
if not tensor.dim() == 1:
raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
o_stride = tensor.stride(0)
numel = tensor.numel()
new_stride = (step * o_stride, dilation * o_stride)
Expand Down
10 changes: 3 additions & 7 deletions torchvision/io/_video_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,9 @@ def __init__(self) -> None:

def _validate_pts(pts_range: Tuple[int, int]) -> None:

if pts_range[1] > 0:
assert (
pts_range[0] <= pts_range[1]
), """Start pts should not be smaller than end pts, got
start pts: {:d} and end pts: {:d}""".format(
pts_range[0],
pts_range[1],
if pts_range[1] > 0 and not (pts_range[0] <= pts_range[1]):
raise ValueError(
f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
)


Expand Down
12 changes: 8 additions & 4 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,10 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
return targets

def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
assert isinstance(boxes, (list, tuple))
assert isinstance(rel_codes, torch.Tensor)
if not isinstance(boxes, (list, tuple)):
raise TypeError(f"This function expects boxes of type list or tuple, instead got {type(boxes)}")
if not isinstance(rel_codes, torch.Tensor):
raise TypeError(f"This function expects rel_codes of type torch.Tensor, instead got {type(rel_codes)}")
boxes_per_image = [b.size(0) for b in boxes]
concat_boxes = torch.cat(boxes, dim=0)
box_sum = 0
Expand Down Expand Up @@ -333,7 +335,8 @@ def __init__(self, high_threshold: float, low_threshold: float, allow_low_qualit
"""
self.BELOW_LOW_THRESHOLD = -1
self.BETWEEN_THRESHOLDS = -2
assert low_threshold <= high_threshold
if not low_threshold <= high_threshold:
raise ValueError("low_threshold should be <= high_threshold")
self.high_threshold = high_threshold
self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches
Expand Down Expand Up @@ -371,7 +374,8 @@ def __call__(self, match_quality_matrix: Tensor) -> Tensor:
matches[between_thresholds] = self.BETWEEN_THRESHOLDS

if self.allow_low_quality_matches:
assert all_matches is not None
if all_matches is None:
raise ValueError("all_matches should not be None")
self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

return matches
Expand Down
58 changes: 29 additions & 29 deletions torchvision/models/detection/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ def __init__(
if not isinstance(aspect_ratios[0], (list, tuple)):
aspect_ratios = (aspect_ratios,) * len(sizes)

assert len(sizes) == len(aspect_ratios)

self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = [
Expand Down Expand Up @@ -86,32 +84,34 @@ def num_anchors_per_location(self):
def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None

if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
raise ValueError(
"Anchors should be Tuple[Tuple[int]] because each feature "
"map could potentially have different sizes and aspect ratios. "
"There needs to be a match between the number of "
"feature maps passed and the number of sizes / aspect ratios specified."
)

for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
grid_height, grid_width = size
stride_height, stride_width = stride
device = base_anchors.device

# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

# For every (base anchor, output anchor) pair,
# offset each zero-centered base anchor by the center of the output anchor.
anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
if cell_anchors is None:
raise RuntimeError("cell_anchors should not be None")
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps we don't need the else block here, which would help minimize the changes?

if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
raise ValueError(
"Anchors should be Tuple[Tuple[int]] because each feature "
"map could potentially have different sizes and aspect ratios. "
"There needs to be a match between the number of "
"feature maps passed and the number of sizes / aspect ratios specified."
)

for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
grid_height, grid_width = size
stride_height, stride_width = stride
device = base_anchors.device

# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

# For every (base anchor, output anchor) pair,
# offset each zero-centered base anchor by the center of the output anchor.
anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))

return anchors

Expand Down Expand Up @@ -164,8 +164,8 @@ def __init__(
clip: bool = True,
):
super().__init__()
if steps is not None:
assert len(aspect_ratios) == len(steps)
if steps is not None and len(aspect_ratios) != len(steps):
raise RuntimeError("aspect_ratios and steps should have the same length")
self.aspect_ratios = aspect_ratios
self.steps = steps
self.clip = clip
Expand Down
Loading