diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index 3777a5aa..92529357 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -4,7 +4,7 @@ Hacked together by / Copyright 2021, Ross Wightman """ import os -from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder +from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder try: from torchvision.datasets import Places365 has_places365 = True @@ -15,6 +15,16 @@ try: has_inaturalist = True except ImportError: has_inaturalist = False +try: + from torchvision.datasets import QMNIST + has_qmnist = True +except ImportError: + has_qmnist = False +try: + from torchvision.datasets import ImageNet + has_imagenet = True +except ImportError: + has_imagenet = False from .dataset import IterableImageDataset, ImageDataset @@ -22,7 +32,6 @@ _TORCH_BASIC_DS = dict( cifar10=CIFAR10, cifar100=CIFAR100, mnist=MNIST, - qmist=QMNIST, kmnist=KMNIST, fashion_mnist=FashionMNIST, )