Skip to content

Commit aefe335

Browse files
author
KaiyangZhou
committed
use dict as return of dataloader
1 parent fe0f39f commit aefe335

3 files changed

Lines changed: 21 additions & 9 deletions

File tree

torchreid/data/datasets/dataset.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,13 @@ def __getitem__(self, index):
265265
img = read_image(img_path)
266266
if self.transform is not None:
267267
img = self.transform(img)
268-
return img, pid, camid, img_path
268+
item = {
269+
'img': img,
270+
'pid': pid,
271+
'camid': camid,
272+
'impath': img_path
273+
}
274+
return item
269275

270276
def show_summary(self):
271277
num_train_pids, num_train_cams = self.parse_data(self.train)
@@ -373,7 +379,13 @@ def __getitem__(self, index):
373379
imgs.append(img)
374380
imgs = torch.cat(imgs, dim=0)
375381

376-
return imgs, pid, camid
382+
item = {
383+
'img': imgs,
384+
'pid': pid,
385+
'camid': camid
386+
}
387+
388+
return item
377389

378390
def show_summary(self):
379391
num_train_pids, num_train_cams = self.parse_data(self.train)

torchreid/engine/engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -449,14 +449,14 @@ def extract_features(self, input):
449449
return self.model(input)
450450

451451
def parse_data_for_train(self, data):
452-
imgs = data[0]
453-
pids = data[1]
452+
imgs = data['img']
453+
pids = data['pid']
454454
return imgs, pids
455455

456456
def parse_data_for_eval(self, data):
457-
imgs = data[0]
458-
pids = data[1]
459-
camids = data[2]
457+
imgs = data['img']
458+
pids = data['pid']
459+
camids = data['camid']
460460
return imgs, pids, camids
461461

462462
def two_stepped_transfer_learning(

torchreid/engine/video/softmax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def __init__(
7878
self.pooling_method = pooling_method
7979

8080
def parse_data_for_train(self, data):
81-
imgs = data[0]
82-
pids = data[1]
81+
imgs = data['img']
82+
pids = data['pid']
8383
if imgs.dim() == 5:
8484
# b: batch size
8585
# s: sqeuence length

0 commit comments

Comments
 (0)