Merge pull request #1564 from HongxinXiang/main

Fix compatible BUG: QMNIST and ImageNet datasets do not exist in torc…
pull/1520/head
Ross Wightman 2 years ago committed by GitHub
commit 98c5e78226
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,7 @@ Hacked together by / Copyright 2021, Ross Wightman
""" """
import os import os
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
try: try:
from torchvision.datasets import Places365 from torchvision.datasets import Places365
has_places365 = True has_places365 = True
@ -15,6 +15,16 @@ try:
has_inaturalist = True has_inaturalist = True
except ImportError: except ImportError:
has_inaturalist = False has_inaturalist = False
try:
from torchvision.datasets import QMNIST
has_qmnist = True
except ImportError:
has_qmnist = False
try:
from torchvision.datasets import ImageNet
has_imagenet = True
except ImportError:
has_imagenet = False
from .dataset import IterableImageDataset, ImageDataset from .dataset import IterableImageDataset, ImageDataset
@ -22,7 +32,6 @@ _TORCH_BASIC_DS = dict(
cifar10=CIFAR10, cifar10=CIFAR10,
cifar100=CIFAR100, cifar100=CIFAR100,
mnist=MNIST, mnist=MNIST,
qmist=QMNIST,
kmnist=KMNIST, kmnist=KMNIST,
fashion_mnist=FashionMNIST, fashion_mnist=FashionMNIST,
) )
@ -122,7 +131,12 @@ def create_dataset(
elif split in _EVAL_SYNONYM: elif split in _EVAL_SYNONYM:
split = 'val' split = 'val'
ds = Places365(split=split, **torch_kwargs) ds = Places365(split=split, **torch_kwargs)
elif name == 'qmnist':
assert has_qmnist, 'Please update to a newer PyTorch and torchvision for QMNIST dataset.'
use_train = split in _TRAIN_SYNONYM
ds = QMNIST(train=use_train, **torch_kwargs)
elif name == 'imagenet': elif name == 'imagenet':
assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
if split in _EVAL_SYNONYM: if split in _EVAL_SYNONYM:
split = 'val' split = 'val'
ds = ImageNet(split=split, **torch_kwargs) ds = ImageNet(split=split, **torch_kwargs)

Loading…
Cancel
Save