|
2 | 2 | import pathlib |
3 | 3 | import random |
4 | 4 | import re |
| 5 | +import textwrap |
5 | 6 | import warnings |
6 | 7 | from collections import defaultdict |
7 | 8 |
|
|
14 | 15 |
|
15 | 16 | from common_utils import ( |
16 | 17 | assert_equal, |
| 18 | + assert_run_python_script, |
17 | 19 | cpu_and_gpu, |
18 | 20 | make_bounding_box, |
19 | 21 | make_bounding_boxes, |
@@ -2045,3 +2047,52 @@ def test_sanitize_bounding_boxes_errors(): |
2045 | 2047 | ) |
2046 | 2048 | different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])} |
2047 | 2049 | transforms.SanitizeBoundingBoxes()(different_sizes) |
| 2050 | + |
| 2051 | + |
| 2052 | +@pytest.mark.parametrize( |
| 2053 | + "import_statement", |
| 2054 | + ( |
| 2055 | + "from torchvision.transforms import v2", |
| 2056 | + "import torchvision.transforms.v2", |
| 2057 | + "from torchvision.transforms.v2 import Resize", |
| 2058 | + "import torchvision.transforms.v2.functional", |
| 2059 | + "from torchvision.transforms.v2.functional import resize", |
| 2060 | + "from torchvision import datapoints", |
| 2061 | + "from torchvision.datapoints import Image", |
| 2062 | + "from torchvision.datasets import wrap_dataset_for_transforms_v2", |
| 2063 | + ), |
| 2064 | +) |
| 2065 | +@pytest.mark.parametrize("call_disable_warning", (True, False)) |
| 2066 | +def test_warnings_v2_namespaces(import_statement, call_disable_warning): |
| 2067 | + if call_disable_warning: |
| 2068 | + source = f""" |
| 2069 | + import warnings |
| 2070 | + import torchvision |
| 2071 | + torchvision.disable_beta_transforms_warning() |
| 2072 | + with warnings.catch_warnings(): |
| 2073 | + warnings.simplefilter("error") |
| 2074 | + {import_statement} |
| 2075 | + """ |
| 2076 | + else: |
| 2077 | + source = f""" |
| 2078 | + import pytest |
| 2079 | + with pytest.warns(UserWarning, match="v2 namespaces are still Beta"): |
| 2080 | + {import_statement} |
| 2081 | + """ |
| 2082 | + assert_run_python_script(textwrap.dedent(source)) |
| 2083 | + |
| 2084 | + |
| 2085 | +def test_no_warnings_v1_namespace(): |
| 2086 | + source = """ |
| 2087 | + import warnings |
| 2088 | + with warnings.catch_warnings(): |
| 2089 | + warnings.simplefilter("error") |
| 2090 | + import torchvision.transforms |
| 2091 | + from torchvision import transforms |
| 2092 | + import torchvision.transforms.functional |
| 2093 | + from torchvision.transforms import Resize |
| 2094 | + from torchvision.transforms.functional import resize |
| 2095 | + from torchvision import datasets |
| 2096 | + from torchvision.datasets import ImageNet |
| 2097 | + """ |
| 2098 | + assert_run_python_script(textwrap.dedent(source)) |
0 commit comments