Skip to content

Commit 885e3c2

Browse files
TheCodez authored and fmassa committed
Support for returning multiple targets (#700)
1 parent 8ce0070 commit 885e3c2

File tree

1 file changed

+53
-14
lines changed

1 file changed

+53
-14
lines changed

torchvision/datasets/cityscapes.py

Lines changed: 53 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -7,18 +7,45 @@
77

88
class Cityscapes(data.Dataset):
99
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
10+
1011
Args:
1112
root (string): Root directory of dataset where directory ``leftImg8bit``
1213
and ``gtFine`` or ``gtCoarse`` are located.
1314
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
1415
otherwise ``train``, ``train_extra`` or ``val``
1516
mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
16-
target_type (string, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
17-
or ``color``
17+
target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
18+
or ``color``. Can also be a list to output a tuple with all specified target types.
1819
transform (callable, optional): A function/transform that takes in a PIL image
1920
and returns a transformed version. E.g, ``transforms.RandomCrop``
2021
target_transform (callable, optional): A function/transform that takes in the
2122
target and transforms it.
23+
24+
Examples:
25+
26+
Get semantic segmentation target
27+
28+
.. code-block:: python
29+
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
30+
target_type='semantic')
31+
32+
img, smnt = dataset[0]
33+
34+
Get multiple targets
35+
36+
.. code-block:: python
37+
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
38+
target_type=['instance', 'color', 'polygon'])
39+
40+
img, (inst, col, poly) = dataset[0]
41+
42+
Validate on the "gtCoarse" set
43+
44+
.. code-block:: python
45+
dataset = Cityscapes('./data/cityscapes', split='val', mode='gtCoarse',
46+
target_type='semantic')
47+
48+
img, smnt = dataset[0]
2249
"""
2350

2451
def __init__(self, root, split='train', mode='gtFine', target_type='instance',
@@ -44,9 +71,12 @@ def __init__(self, root, split='train', mode='gtFine', target_type='instance',
4471
raise ValueError('Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
4572
' or split="val"')
4673

47-
if target_type not in ['instance', 'semantic', 'polygon', 'color']:
48-
raise ValueError('Invalid value for "target_type"! Please use target_type="instance",'
49-
' target_type="semantic", target_type="polygon" or target_type="color"')
74+
if not isinstance(target_type, list):
75+
self.target_type = [target_type]
76+
77+
if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type):
78+
raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
79+
' or "color"')
5080

5181
if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
5282
raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
@@ -56,27 +86,36 @@ def __init__(self, root, split='train', mode='gtFine', target_type='instance',
5686
img_dir = os.path.join(self.images_dir, city)
5787
target_dir = os.path.join(self.targets_dir, city)
5888
for file_name in os.listdir(img_dir):
59-
target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
60-
self._get_target_suffix(self.mode, self.target_type))
89+
target_types = []
90+
for t in self.target_type:
91+
target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
92+
self._get_target_suffix(self.mode, t))
93+
target_types.append(os.path.join(target_dir, target_name))
6194

6295
self.images.append(os.path.join(img_dir, file_name))
63-
self.targets.append(os.path.join(target_dir, target_name))
96+
self.targets.append(target_types)
6497

6598
def __getitem__(self, index):
6699
"""
67100
Args:
68101
index (int): Index
69102
Returns:
70-
tuple: (image, target) where target is a json object if target_type="polygon",
71-
otherwise the image segmentation.
103+
tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
104+
than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
72105
"""
73106

74107
image = Image.open(self.images[index]).convert('RGB')
75108

76-
if self.target_type == 'polygon':
77-
target = self._load_json(self.targets[index])
78-
else:
79-
target = Image.open(self.targets[index])
109+
targets = []
110+
for i, t in enumerate(self.target_type):
111+
if t == 'polygon':
112+
target = self._load_json(self.targets[index][i])
113+
else:
114+
target = Image.open(self.targets[index][i])
115+
116+
targets.append(target)
117+
118+
target = tuple(targets) if len(targets) > 1 else targets[0]
80119

81120
if self.transform:
82121
image = self.transform(image)

0 commit comments

Comments
 (0)