@@ -15,7 +15,7 @@
 from datasets_utils import combinations_grid
 from torch.nn.functional import one_hot
 from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
 from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
 from torchvision.transforms.functional_tensor import _max_value as get_max_value
 
@@ -238,7 +238,7 @@ def load(self, device):
 
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
-    color_space: features.ColorSpace
+    color_space: datapoints.ColorSpace
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
 
@@ -248,10 +248,10 @@ def __post_init__(self):
 
 
 NUM_CHANNELS_MAP = {
-    features.ColorSpace.GRAY: 1,
-    features.ColorSpace.GRAY_ALPHA: 2,
-    features.ColorSpace.RGB: 3,
-    features.ColorSpace.RGB_ALPHA: 4,
+    datapoints.ColorSpace.GRAY: 1,
+    datapoints.ColorSpace.GRAY_ALPHA: 2,
+    datapoints.ColorSpace.RGB: 3,
+    datapoints.ColorSpace.RGB_ALPHA: 4,
 }
 
 
@@ -265,7 +265,7 @@ def get_num_channels(color_space):
 def make_image_loader(
     size="random",
     *,
-    color_space=features.ColorSpace.RGB,
+    color_space=datapoints.ColorSpace.RGB,
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,
@@ -276,9 +276,9 @@ def make_image_loader(
     def fn(shape, dtype, device):
         max_value = get_max_value(dtype)
         data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
-        if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha:
+        if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha:
             data[..., -1, :, :] = max_value
-        return features.Image(data, color_space=color_space)
+        return datapoints.Image(data, color_space=color_space)
 
     return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space)
 
@@ -290,10 +290,10 @@ def make_image_loaders(
     *,
     sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
-        features.ColorSpace.GRAY,
-        features.ColorSpace.GRAY_ALPHA,
-        features.ColorSpace.RGB,
-        features.ColorSpace.RGB_ALPHA,
+        datapoints.ColorSpace.GRAY,
+        datapoints.ColorSpace.GRAY_ALPHA,
+        datapoints.ColorSpace.RGB,
+        datapoints.ColorSpace.RGB_ALPHA,
     ),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.float32, torch.uint8),
@@ -306,7 +306,7 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)
 
 
-def make_image_loader_for_interpolation(size="random", *, color_space=features.ColorSpace.RGB, dtype=torch.uint8):
+def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
@@ -318,24 +318,24 @@ def fn(shape, dtype, device):
             .resize((width, height))
             .convert(
                 {
-                    features.ColorSpace.GRAY: "L",
-                    features.ColorSpace.GRAY_ALPHA: "LA",
-                    features.ColorSpace.RGB: "RGB",
-                    features.ColorSpace.RGB_ALPHA: "RGBA",
+                    datapoints.ColorSpace.GRAY: "L",
+                    datapoints.ColorSpace.GRAY_ALPHA: "LA",
+                    datapoints.ColorSpace.RGB: "RGB",
+                    datapoints.ColorSpace.RGB_ALPHA: "RGBA",
                 }[color_space]
             )
         )
 
         image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
 
-        return features.Image(image_tensor, color_space=color_space)
+        return datapoints.Image(image_tensor, color_space=color_space)
 
     return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space)
 
 
 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
-    color_spaces=(features.ColorSpace.RGB,),
+    color_spaces=(datapoints.ColorSpace.RGB,),
     dtypes=(torch.uint8,),
 ):
     for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
@@ -344,7 +344,7 @@ def make_image_loaders_for_interpolation(
 
 @dataclasses.dataclass
 class BoundingBoxLoader(TensorLoader):
-    format: features.BoundingBoxFormat
+    format: datapoints.BoundingBoxFormat
     spatial_size: Tuple[int, int]
 
 
@@ -362,11 +362,11 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
 
 def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32):
     if isinstance(format, str):
-        format = features.BoundingBoxFormat[format]
+        format = datapoints.BoundingBoxFormat[format]
     if format not in {
-        features.BoundingBoxFormat.XYXY,
-        features.BoundingBoxFormat.XYWH,
-        features.BoundingBoxFormat.CXCYWH,
+        datapoints.BoundingBoxFormat.XYXY,
+        datapoints.BoundingBoxFormat.XYWH,
+        datapoints.BoundingBoxFormat.CXCYWH,
     }:
         raise pytest.UsageError(f"Can't make bounding box in format {format}")
 
@@ -378,19 +378,19 @@ def fn(shape, dtype, device):
             raise pytest.UsageError()
 
         if any(dim == 0 for dim in extra_dims):
-            return features.BoundingBox(
+            return datapoints.BoundingBox(
                 torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
             )
 
         height, width = spatial_size
 
-        if format == features.BoundingBoxFormat.XYXY:
+        if format == datapoints.BoundingBoxFormat.XYXY:
             x1 = torch.randint(0, width // 2, extra_dims)
             y1 = torch.randint(0, height // 2, extra_dims)
             x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
             y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
             parts = (x1, y1, x2, y2)
-        elif format == features.BoundingBoxFormat.XYWH:
+        elif format == datapoints.BoundingBoxFormat.XYWH:
             x = torch.randint(0, width // 2, extra_dims)
             y = torch.randint(0, height // 2, extra_dims)
             w = randint_with_tensor_bounds(1, width - x)
@@ -403,7 +403,7 @@ def fn(shape, dtype, device):
             h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
             parts = (cx, cy, w, h)
 
-        return features.BoundingBox(
+        return datapoints.BoundingBox(
             torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
         )
 
@@ -416,7 +416,7 @@ def fn(shape, dtype, device):
 def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
-    formats=tuple(features.BoundingBoxFormat),
+    formats=tuple(datapoints.BoundingBoxFormat),
     spatial_size="random",
     dtypes=(torch.float32, torch.int64),
 ):
@@ -456,7 +456,7 @@ def fn(shape, dtype, device):
         # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
         # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
         data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
-        return features.Label(data, categories=categories)
+        return datapoints.Label(data, categories=categories)
 
     return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
 
@@ -480,7 +480,7 @@ def fn(shape, dtype, device):
         # since `one_hot` only supports int64
         label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
         data = one_hot(label, num_classes=num_categories).to(dtype)
-        return features.OneHotLabel(data, categories=categories)
+        return datapoints.OneHotLabel(data, categories=categories)
 
     return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
 
@@ -509,7 +509,7 @@ def make_detection_mask_loader(size="random", *, num_objects="random", extra_dim
 
     def fn(shape, dtype, device):
         data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device)
-        return features.Mask(data)
+        return datapoints.Mask(data)
 
     return MaskLoader(fn, shape=(*extra_dims, num_objects, *size), dtype=dtype)
 
@@ -537,7 +537,7 @@ def make_segmentation_mask_loader(size="random", *, num_categories="random", ext
 
     def fn(shape, dtype, device):
         data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device)
-        return features.Mask(data)
+        return datapoints.Mask(data)
 
     return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype)
 
@@ -583,7 +583,7 @@ class VideoLoader(ImageLoader):
 def make_video_loader(
     size="random",
     *,
-    color_space=features.ColorSpace.RGB,
+    color_space=datapoints.ColorSpace.RGB,
     num_frames="random",
     extra_dims=(),
     dtype=torch.uint8,
@@ -593,7 +593,7 @@ def make_video_loader(
 
     def fn(shape, dtype, device):
         video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
-        return features.Video(video, color_space=color_space)
+        return datapoints.Video(video, color_space=color_space)
 
     return VideoLoader(
         fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
@@ -607,8 +607,8 @@ def make_video_loaders(
     *,
     sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
-        features.ColorSpace.GRAY,
-        features.ColorSpace.RGB,
+        datapoints.ColorSpace.GRAY,
+        datapoints.ColorSpace.RGB,
     ),
     num_frames=(1, 0, "random"),
    extra_dims=DEFAULT_EXTRA_DIMS,
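For reference, the change is a pure namespace rename: every helper keeps constructing the same wrapper types with the same keyword arguments, only `features` becomes `datapoints`. A minimal sketch of the construction pattern used by fn() inside make_image_loader, assuming a torchvision build that exposes the prototype `datapoints` namespace and the `ColorSpace` enum shown in this diff:

    import torch
    from torchvision.prototype import datapoints  # prototype API from this diff; subject to change

    # Mirror fn() in make_image_loader: wrap a CHW uint8 tensor as an RGB image datapoint.
    data = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    image = datapoints.Image(data, color_space=datapoints.ColorSpace.RGB)

    # The wrapper is a torch.Tensor subclass, so ordinary tensor ops still apply.
    print(isinstance(image, torch.Tensor), image.shape)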