Merge pull request #1564 from HongxinXiang/main

Fix compatible BUG: QMNIST and ImageNet datasets do not exist in torc…
pull/1520/head
Ross Wightman 2 years ago committed by GitHub
commit 98c5e78226
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,7 @@ Hacked together by / Copyright 2021, Ross Wightman
"""
import os
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
try:
from torchvision.datasets import Places365
has_places365 = True
@ -15,6 +15,16 @@ try:
has_inaturalist = True
except ImportError:
has_inaturalist = False
try:
from torchvision.datasets import QMNIST
has_qmnist = True
except ImportError:
has_qmnist = False
try:
from torchvision.datasets import ImageNet
has_imagenet = True
except ImportError:
has_imagenet = False
from .dataset import IterableImageDataset, ImageDataset
@ -22,7 +32,6 @@ _TORCH_BASIC_DS = dict(
cifar10=CIFAR10,
cifar100=CIFAR100,
mnist=MNIST,
qmist=QMNIST,
kmnist=KMNIST,
fashion_mnist=FashionMNIST,
)
@ -122,7 +131,12 @@ def create_dataset(
elif split in _EVAL_SYNONYM:
split = 'val'
ds = Places365(split=split, **torch_kwargs)
elif name == 'qmnist':
assert has_qmnist, 'Please update to a newer PyTorch and torchvision for QMNIST dataset.'
use_train = split in _TRAIN_SYNONYM
ds = QMNIST(train=use_train, **torch_kwargs)
elif name == 'imagenet':
assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
if split in _EVAL_SYNONYM:
split = 'val'
ds = ImageNet(split=split, **torch_kwargs)

Loading…
Cancel
Save