Skip to content

Commit 1614608

Browse files
authored
Fix drop_last not working on IterableDataset (#34801)
* Fix drop_last not working on IterableDataset. test=develop
1 parent 181f7ce commit 1614608

File tree

5 files changed

+36
-8
lines changed

5 files changed

+36
-8
lines changed

python/paddle/fluid/dataloader/dataloader_iter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(self, loader):
5959
self._places = loader.places
6060
self._return_list = loader.return_list
6161
self._batch_sampler = loader.batch_sampler
62+
self._drop_last = loader.drop_last
6263
self._auto_collate_batch = loader.auto_collate_batch
6364
self._num_workers = loader.num_workers
6465
self._use_buffer_reader = loader.use_buffer_reader
@@ -111,7 +112,7 @@ def __init__(self, loader):
111112

112113
self._dataset_fetcher = _DatasetKind.create_fetcher(
113114
self._dataset_kind, self._dataset, self._auto_collate_batch,
114-
self._collate_fn, True)
115+
self._collate_fn, self._drop_last)
115116

116117
# NOTE: _structure_infos is used to record the data structure of
117118
# batch to restore batch structure after reading Tensor
@@ -309,8 +310,8 @@ def _init_workers(self):
309310
args=(self._dataset, self._dataset_kind, indices_queue,
310311
self._data_queue, self._workers_done_event,
311312
self._auto_collate_batch, self._collate_fn,
312-
self._worker_init_fn, i, self._num_workers,
313-
self._use_shared_memory))
313+
self._drop_last, self._worker_init_fn, i,
314+
self._num_workers, self._use_shared_memory))
314315
worker.daemon = True
315316
worker.start()
316317
self._workers.append(worker)

python/paddle/fluid/dataloader/worker.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def mix(x, y):
253253

254254

255255
def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
256-
auto_collate_batch, collate_fn, init_fn, worker_id,
256+
auto_collate_batch, collate_fn, drop_last, init_fn, worker_id,
257257
num_workers, use_shared_memory):
258258
try:
259259
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
@@ -282,8 +282,9 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
282282
try:
283283
if init_fn is not None:
284284
init_fn(worker_id)
285-
fetcher = _DatasetKind.create_fetcher(
286-
dataset_kind, dataset, auto_collate_batch, collate_fn, True)
285+
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
286+
auto_collate_batch,
287+
collate_fn, drop_last)
287288
except:
288289
init_exception = _WorkerException(worker_id)
289290

python/paddle/fluid/reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def __init__(self,
401401
shuffle=shuffle,
402402
drop_last=drop_last)
403403

404+
self.drop_last = drop_last
404405
self.auto_collate_batch = self.batch_sampler is not None
405406

406407
self.pin_memory = False

python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,5 +397,30 @@ def test_main(self):
397397
assert out == outp
398398

399399

400+
class TestDatasetWithDropLast(unittest.TestCase):
401+
def run_main(self, dataset, num_samples, batch_size):
402+
for num_workers in [0, 1]:
403+
for drop_last in [True, False]:
404+
steps = (num_samples + (1 - int(drop_last)) * \
405+
(batch_size - 1)) // batch_size
406+
dataloader = DataLoader(
407+
dataset,
408+
batch_size=batch_size,
409+
drop_last=drop_last,
410+
num_workers=num_workers)
411+
datas = []
412+
for data in dataloader:
413+
datas.append(data)
414+
assert len(datas) == steps
415+
416+
def test_map_dataset(self):
417+
dataset = RandomDataset(10)
418+
self.run_main(dataset, 10, 3)
419+
420+
def test_iterable_dataset(self):
421+
dataset = RandomIterableDataset(10)
422+
self.run_main(dataset, 10, 3)
423+
424+
400425
if __name__ == '__main__':
401426
unittest.main()

python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _collate_fn(sample_list):
180180
indices_queue.put(None)
181181
_worker_loop(loader._dataset, 0, indices_queue,
182182
loader._data_queue, loader._workers_done_event,
183-
True, _collate_fn, _init_fn, 0, 1,
183+
True, _collate_fn, True, _init_fn, 0, 1,
184184
loader._use_shared_memory)
185185
self.assertTrue(False)
186186
except AssertionError:
@@ -224,7 +224,7 @@ def _collate_fn(sample_list):
224224
loader._workers_done_event.set()
225225
_worker_loop(loader._dataset, 0, indices_queue,
226226
loader._data_queue, loader._workers_done_event,
227-
True, _collate_fn, _init_fn, 0, 1,
227+
True, _collate_fn, True, _init_fn, 0, 1,
228228
loader._use_shared_memory)
229229
self.assertTrue(True)
230230
except AssertionError:

0 commit comments

Comments
 (0)