SVHN dataset for torchvision #98
Replacing scipy.misc.lena() with scipy.misc.face()
Syncing forked repository
Adding SVHN dataset http://ufldl.stanford.edu/housenumbers/ (Format 2) for torchvision
|
@uridah do you have a notebook that checks for sanity of this code, like the notebook that loads some CIFAR10 images and displays them: |
|
@soumith I don't but I can create one. |
|
I don't know if PyTorch follows PEP8 as a style guide? If so, you may want to run flake8 on your code and fix the issues it finds; I see a few whitespace issues in it. :-) If not, I'll leave it to the project owners to set the requirements.
|
thanks @mjpieters . We have a LINT check as part of the contbuild, and it is failing: https://travis-ci.org/pytorch/vision/builds/209884377 Locally doing: will show exact LINT errors you need to fix @uridah |
torchvision/datasets/svhn.py
Outdated
raise RuntimeError('Dataset not found or corrupted.' +
                   ' You can use download=True to download it')

self.train_data = []
This comment was marked as off-topic.
torchvision/datasets/svhn.py
Outdated
self.dataset = dataset  # training set or test set or extra set

# download and load the data
if self.dataset=='train':
This comment was marked as off-topic.
torchvision/datasets/svhn.py
Outdated
def __len__(self):
    if self.dataset == 'train':
        return 73257
This comment was marked as off-topic.
torchvision/datasets/svhn.py
Outdated
self.root = root
self.transform = transform
self.target_transform = target_transform
self.dataset = dataset  # training set or test set or extra set
This comment was marked as off-topic.
|
@uridah I left a few small inline comments (all minor, style-only changes). The PR looks good, thanks!
- now using dictionary for urls, filenames and md5s
- updated len function
- renamed 'dataset' keyword to split
- fixed whitespaces using flake8
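For readers following the thread, here is a minimal sketch of what a dictionary-driven split lookup could look like. It is not the code from the PR: the names `SPLIT_LIST` and `resolve_split` are hypothetical, the file names are assumed from the Format 2 downloads on the SVHN page, and the md5 strings are placeholders rather than real checksums.

```python
# Hypothetical sketch: one dictionary holds the url, filename and md5 per split.
BASE_URL = "http://ufldl.stanford.edu/housenumbers/"

SPLIT_LIST = {
    # split: (url, filename, md5) -- md5 values below are placeholders
    "train": (BASE_URL + "train_32x32.mat", "train_32x32.mat", "<md5-train>"),
    "test": (BASE_URL + "test_32x32.mat", "test_32x32.mat", "<md5-test>"),
    "extra": (BASE_URL + "extra_32x32.mat", "extra_32x32.mat", "<md5-extra>"),
}


def resolve_split(split):
    """Return (url, filename, md5) for a split, or raise ValueError."""
    if split not in SPLIT_LIST:
        raise ValueError(
            "Wrong split entered! Please use split=train "
            "or split=extra or split=test"
        )
    return SPLIT_LIST[split]


url, filename, md5 = resolve_split("train")
print(filename)
```

Keeping the per-split metadata in one place means the constructor needs a single membership check instead of a chain of `if`/`elif` branches, and `__len__` can simply return the length of whatever array was loaded.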
|
@fmassa Your comments were very useful and really helped make the code more concise.
torchvision/datasets/svhn.py
Outdated
# reading(loading) mat file as array
loaded_mat = sio.loadmat(os.path.join(root, self.filename))

if self.split != 'test':
This comment was marked as off-topic.
torchvision/datasets/svhn.py
Outdated
    self.test_labels = loaded_mat['y']
    self.test_data = np.transpose(self.test_data, (3, 2, 1, 0))
else:
    print ("Wrong dataset entered! Please use split=train or split=extra or split=test")
This comment was marked as off-topic.
torchvision/datasets/svhn.py
Outdated
self.target_transform = target_transform
self.split = split  # training set or test set or extra set

if self.split in self.split_list:
This comment was marked as off-topic.
torchvision/datasets/svhn.py
Outdated
        print ("Wrong dataset entered! Please use split=train or split=extra or split=test")

def __getitem__(self, index):
    if self.split == 'train' or self.split == 'extra':
This comment was marked as off-topic.
torchvision/datasets/svhn.py
Outdated
    return img, target

def __len__(self):
    return len(self.train_data)
This comment was marked as off-topic.
|
Hi @uridah |
|
Thanks @fmassa. Updated according to your suggestions |
torchvision/datasets/svhn.py
Outdated
if self.split not in self.split_list:
    raise ValueError('Wrong split entered! Please use split=train or split=extra or split=test')
else:
This comment was marked as off-topic.
torchvision/datasets/svhn.py
Outdated
    self.download()

if not self._check_integrity():
    raise RuntimeError('Dataset not found or corrupted.' +
This comment was marked as off-topic.
|
About the |
|
I will, if you guys think it's good enough. |
|
@uridah another thing, it would be great if you could add an entry in the doc in |
|
Hi @uridah |
|
@fmassa as it turns out, I needed to transpose along the 3,2,0,1 axes instead of 3,2,1,0. I have fixed that and updated the indentation and sanity_checks1.ipynb. Please have a look and let me know if anything else needs to be changed.
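To illustrate why the axis order matters, here is a small sketch. It assumes the array stored under `'X'` in the .mat file is laid out as (height, width, channels, N); the dimensions below are deliberately non-square stand-ins (not the real 32x32 images) so that the difference between the two orders shows up in the shapes.

```python
import numpy as np

# Stand-in for loaded_mat['X'], assumed layout: (height, width, channels, N).
height, width, channels, num_images = 4, 6, 3, 5
x = np.arange(height * width * channels * num_images).reshape(
    height, width, channels, num_images
)

# (3, 2, 1, 0) yields (N, C, W, H): every image ends up transposed.
swapped = np.transpose(x, (3, 2, 1, 0))

# (3, 2, 0, 1) yields (N, C, H, W): rows stay rows, columns stay columns.
fixed = np.transpose(x, (3, 2, 0, 1))

print(swapped.shape)  # (5, 3, 6, 4)
print(fixed.shape)    # (5, 3, 4, 6)
```

With square images both orders produce the same shape, so the mistake only becomes visible when the images are actually displayed, which is what a sanity-check notebook is for.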
|
Thanks Uridah. As you saw in the last 4 commits, I made some minor changes to your PR, but it looked great. Merged into master now!
|
Hello, it seems that the labels returned are in the range Given that some of the loss functions (CELoss, NLLLoss) expect the class labels to be in the range Another thing is that the returned I can make a PR if the current behaviour is inconsistent and should be made similar to the other data sets, for example as in CIFAR10
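The exact ranges did not survive in the comment above, so the following is only a sketch under an explicit assumption: that the raw labels follow the convention described on the SVHN dataset page, where the digit 0 is stored as label 10. The helper name `remap_svhn_labels` is hypothetical, and flattening the label array is an additional assumption made for convenience.

```python
import numpy as np


def remap_svhn_labels(labels):
    """Map the label 10 to 0 so that classes are zero-based (assumed convention)."""
    # Flatten to 1-D and widen the dtype; astype returns a copy.
    flat = np.asarray(labels).astype(np.int64).reshape(-1)
    # Leave every other label untouched.
    return np.where(flat == 10, 0, flat)


raw = np.array([[1], [10], [9]], dtype=np.uint8)  # column vector, as loadmat tends to return
print(remap_svhn_labels(raw))
```

Zero-based class indices are what index-based classification losses generally expect, which is why the commenter suggests aligning the behaviour with the other datasets.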
|
@vabh if you could make a PR to reflect this, that'd be great. Thanks. |
…mlperf/community (pytorch#98)
In reference to #59
cc: @mjpieters