Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ STL10
- ``download`` : ``True`` = downloads the dataset from the internet and
puts it in root directory. If dataset already downloaded, does not do
anything.

SVHN
~~~~~

``dset.SVHN(root, split='train', transform=None, target_transform=None, download=False)``

- ``root`` : root directory of the dataset, where the ``.mat`` file of the chosen split (e.g. ``train_32x32.mat``) is stored or will be downloaded to
- ``split`` : ``'train'`` = Training set, ``'test'`` = Test set, ``'extra'`` = Extra training set
- ``download`` : ``True`` = downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, does not do
anything.

ImageFolder
~~~~~~~~~~~
Expand Down
540 changes: 540 additions & 0 deletions test/sanity_checks1.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
from .mnist import MNIST
from .svhn import SVHN

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100',
'MNIST', 'STL10')
'MNIST', 'STL10', 'SVHN')
111 changes: 111 additions & 0 deletions torchvision/datasets/svhn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys


class SVHN(data.Dataset):
    """Dataset backed by one of the SVHN ``*_32x32.mat`` files.

    Args:
        root: directory that holds (or will receive) the ``.mat`` file of
            the selected split.
        split: one of the keys of ``split_list``: ``'train'``, ``'test'``
            or ``'extra'``.
        transform: optional callable applied to the PIL image of a sample.
        target_transform: optional callable applied to the label of a sample.
        download: if true, fetch the file into ``root`` unless a copy with
            the expected MD5 is already there.

    Raises:
        ValueError: if ``split`` is not a key of ``split_list``.
        RuntimeError: if the file is missing or fails the MD5 check.
    """

    # Placeholders; __init__ overwrites them on the instance with the
    # entries of the selected split.
    url = ""
    filename = ""
    file_md5 = ""

    # split name -> [download URL, local file name, expected MD5 hex digest]
    split_list = {
        'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                  "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
        'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
                 "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
        'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
                  "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # training set or test set or extra set

        if self.split not in self.split_list:
            raise ValueError('Wrong split entered! Please use split="train" or split="extra" or split="test"')

        self.url, self.filename, self.file_md5 = self.split_list[split]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio

        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(os.path.join(root, self.filename))

        # NOTE(review): presumably 'X' is stored as (H, W, C, N) and 'y' as
        # (N, 1), which would make every target a length-1 array -- confirm
        # against the dataset description before relying on label shape.
        self.labels = loaded_mat['y']
        # Move the last (sample) axis to the front so that indexing axis 0
        # selects one sample; __getitem__ undoes the remaining permutation.
        self.data = np.transpose(loaded_mat['X'], (3, 2, 0, 1))

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        The image is a PIL Image (before ``transform``); the target is
        ``self.labels[index]`` (before ``target_transform``).
        """
        img, target = self.data[index], self.labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image; the transpose restores the per-sample
        # axis order the array had in the .mat file
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples (length of ``self.data``)."""
        return len(self.data)

    def _check_integrity(self):
        """Return True iff the split's file exists in ``root`` with the expected MD5."""
        import hashlib

        expected_md5 = self.split_list[self.split][2]
        fpath = os.path.join(self.root, self.filename)
        if not os.path.isfile(fpath):
            return False

        # Hash in chunks inside a context manager: the file handle is always
        # closed and the whole .mat file never has to sit in memory at once.
        digest = hashlib.md5()
        with open(fpath, 'rb') as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b''):
                digest.update(chunk)
        return digest.hexdigest() == expected_md5

    def download(self):
        """Fetch the split's file into ``root`` unless a verified copy exists.

        Creates ``root`` if needed. A file that is present but fails the MD5
        check (e.g. an interrupted download) is downloaded again.
        """
        root = self.root
        fpath = os.path.join(root, self.filename)

        try:
            os.makedirs(root)
        except OSError as e:
            # An already existing directory is fine; anything else is not.
            if e.errno != errno.EEXIST:
                raise

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        # Imported only when a download is really needed.
        from six.moves import urllib

        # Reaching this point with an existing file means its checksum did
        # not match, so reusing it would only fail later in __init__.
        if os.path.isfile(fpath):
            print('Existing file failed the MD5 check, downloading again: ' + fpath)
        print('Downloading ' + self.url + ' to ' + fpath)
        urllib.request.urlretrieve(self.url, fpath)
        print('Downloaded!')