diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index a6f2af387dc..c5d4feecdf6 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -35,6 +35,11 @@ Fashion-MNIST .. autoclass:: FashionMNIST +KMNIST +~~~~~~~~~~~~~ + +.. autoclass:: KMNIST + EMNIST ~~~~~~ diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 6ab5722315d..22cb5e9405d 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -3,7 +3,7 @@ from .coco import CocoCaptions, CocoDetection from .cifar import CIFAR10, CIFAR100 from .stl10 import STL10 -from .mnist import MNIST, EMNIST, FashionMNIST +from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST from .svhn import SVHN from .phototour import PhotoTour from .fakedata import FakeData @@ -17,6 +17,6 @@ 'ImageFolder', 'DatasetFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', - 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', + 'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k', 'VOCSegmentation', 'VOCDetection') diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 90649178a95..675fe5edc7f 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -179,6 +179,31 @@ class FashionMNIST(MNIST): 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] +class KMNIST(MNIST): + """`Kuzushiji-MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``processed/training.pt`` + and ``processed/test.pt`` exist. + train (bool, optional): If True, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + urls = [ + 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz', + 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz', + 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz', + 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz', + ] + classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo'] + + class EMNIST(MNIST): """`EMNIST `_ Dataset.