diff --git a/dataset/images.py b/dataset/images.py index bd3c552..759fd96 100644 --- a/dataset/images.py +++ b/dataset/images.py @@ -7,8 +7,11 @@ class ImagesDataset(Dataset): def __init__(self, root, mode='RGB', transforms=None): self.transforms = transforms self.mode = mode - self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True), - *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)]) + if os.path.isfile(root): # if root is a file. + self.filenames = [root] + else: + self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True), + *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)]) def __len__(self): return len(self.filenames)