77
88class Cityscapes (data .Dataset ):
99 """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
10+
1011 Args:
1112 root (string): Root directory of dataset where directory ``leftImg8bit``
1213 and ``gtFine`` or ``gtCoarse`` are located.
1314 split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
1415 otherwise ``train``, ``train_extra`` or ``val``
1516 mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
16- target_type (string, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
17- or ``color``
17+ target_type (string or list , optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
18+ or ``color``. Can also be a list to output a tuple with all specified target types.
1819 transform (callable, optional): A function/transform that takes in a PIL image
1920 and returns a transformed version. E.g, ``transforms.RandomCrop``
2021 target_transform (callable, optional): A function/transform that takes in the
2122 target and transforms it.
23+
24+ Examples:
25+
26+ Get semantic segmentation target
27+
28+ .. code-block:: python
29+ dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
30+ target_type='semantic')
31+
32+ img, smnt = dataset[0]
33+
34+ Get multiple targets
35+
36+ .. code-block:: python
37+ dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
38+ target_type=['instance', 'color', 'polygon'])
39+
40+ img, (inst, col, poly) = dataset[0]
41+
42+ Validate on the "gtCoarse" set
43+
44+ .. code-block:: python
45+ dataset = Cityscapes('./data/cityscapes', split='val', mode='gtCoarse',
46+ target_type='semantic')
47+
48+ img, smnt = dataset[0]
2249 """
2350
2451 def __init__ (self , root , split = 'train' , mode = 'gtFine' , target_type = 'instance' ,
@@ -44,9 +71,12 @@ def __init__(self, root, split='train', mode='gtFine', target_type='instance',
4471 raise ValueError ('Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
4572 ' or split="val"' )
4673
47- if target_type not in ['instance' , 'semantic' , 'polygon' , 'color' ]:
48- raise ValueError ('Invalid value for "target_type"! Please use target_type="instance",'
49- ' target_type="semantic", target_type="polygon" or target_type="color"' )
74+ if not isinstance (target_type , list ):
75+ self .target_type = [target_type ]
76+
77+ if not all (t in ['instance' , 'semantic' , 'polygon' , 'color' ] for t in self .target_type ):
78+ raise ValueError ('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
79+ ' or "color"' )
5080
5181 if not os .path .isdir (self .images_dir ) or not os .path .isdir (self .targets_dir ):
5282 raise RuntimeError ('Dataset not found or incomplete. Please make sure all required folders for the'
@@ -56,27 +86,36 @@ def __init__(self, root, split='train', mode='gtFine', target_type='instance',
5686 img_dir = os .path .join (self .images_dir , city )
5787 target_dir = os .path .join (self .targets_dir , city )
5888 for file_name in os .listdir (img_dir ):
59- target_name = '{}_{}' .format (file_name .split ('_leftImg8bit' )[0 ],
60- self ._get_target_suffix (self .mode , self .target_type ))
89+ target_types = []
90+ for t in self .target_type :
91+ target_name = '{}_{}' .format (file_name .split ('_leftImg8bit' )[0 ],
92+ self ._get_target_suffix (self .mode , t ))
93+ target_types .append (os .path .join (target_dir , target_name ))
6194
6295 self .images .append (os .path .join (img_dir , file_name ))
63- self .targets .append (os . path . join ( target_dir , target_name ) )
96+ self .targets .append (target_types )
6497
6598 def __getitem__ (self , index ):
6699 """
67100 Args:
68101 index (int): Index
69102 Returns:
70- tuple: (image, target) where target is a json object if target_type="polygon",
71- otherwise the image segmentation.
103+ tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
104+ than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
72105 """
73106
74107 image = Image .open (self .images [index ]).convert ('RGB' )
75108
76- if self .target_type == 'polygon' :
77- target = self ._load_json (self .targets [index ])
78- else :
79- target = Image .open (self .targets [index ])
109+ targets = []
110+ for i , t in enumerate (self .target_type ):
111+ if t == 'polygon' :
112+ target = self ._load_json (self .targets [index ][i ])
113+ else :
114+ target = Image .open (self .targets [index ][i ])
115+
116+ targets .append (target )
117+
118+ target = tuple (targets ) if len (targets ) > 1 else targets [0 ]
80119
81120 if self .transform :
82121 image = self .transform (image )
0 commit comments