Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions python/mxnet/gluon/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,7 @@ def transform_first(self, fn, lazy=True):
Dataset
The transformed dataset.
"""
def base_fn(x, *args):
if args:
return (fn(x),) + args
return fn(x)
return self.transform(base_fn, lazy)
return self.transform(_TransformFirstClosure(fn), lazy)


class SimpleDataset(Dataset):
Expand Down Expand Up @@ -129,6 +125,16 @@ def __getitem__(self, idx):
return self._fn(item)


class _TransformFirstClosure(object):
    """Callable wrapper that applies a function to the first element only.

    A callable object is used instead of a nested function because a nested
    function cannot be pickled, while an instance of a module-level class can
    (provided the wrapped ``fn`` is itself picklable).

    Parameters
    ----------
    fn : callable
        Function applied to the first positional argument of each call.
    """
    def __init__(self, fn):
        self._fn = fn

    def __call__(self, x, *args):
        """Apply the wrapped function to ``x`` and pass ``args`` through.

        Returns a tuple ``(fn(x),) + args`` when extra positional arguments
        are given, otherwise the bare result ``fn(x)``.
        """
        if args:
            # Only the first element is transformed; the rest are kept as-is.
            return (self._fn(x),) + args
        return self._fn(x)

class ArrayDataset(Dataset):
"""A dataset that combines multiple dataset-like objects, e.g.
Datasets, lists, arrays, etc.
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def _dataset_transform_fn(x, y):
"""Named transform function since lambda function cannot be pickled."""
return x, y

def _dataset_transform_first_fn(x):
    """Identity transform for ``transform_first`` tests.

    Defined as a named module-level function since a lambda function cannot
    be pickled.
    """
    return x

@with_seed()
def test_recordimage_dataset_with_data_loader_multiworker():
recfile = prepare_record()
Expand All @@ -95,17 +99,13 @@ def test_recordimage_dataset_with_data_loader_multiworker():
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i

# try limit recursion depth
import sys
old_limit = sys.getrecursionlimit()
sys.setrecursionlimit(500) # this should be smaller than any default value used in python
dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
# with transform_first
dataset = gluon.data.vision.ImageRecordDataset(recfile).transform_first(_dataset_transform_first_fn)
loader = gluon.data.DataLoader(dataset, 1, num_workers=5)

for i, (x, y) in enumerate(loader):
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i
sys.setrecursionlimit(old_limit)

@with_seed()
def test_sampler():
Expand Down