diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index e56da8bbacc..8feab0ec574 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -373,8 +373,8 @@ def fn(shape, dtype, device): h = randint_with_tensor_bounds(1, height - y) parts = (x, y, w, h) else: # format == features.BoundingBoxFormat.CXCYWH: - cx = torch.randint(1, width - 1, ()) - cy = torch.randint(1, height - 1, ()) + cx = torch.randint(1, width - 1, extra_dims) + cy = torch.randint(1, height - 1, extra_dims) w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1) h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1) parts = (cx, cy, w, h)