|
|
@ -4,7 +4,7 @@ Hacked together by / Copyright 2021, Ross Wightman
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
|
|
|
|
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
from torchvision.datasets import Places365
|
|
|
|
from torchvision.datasets import Places365
|
|
|
|
has_places365 = True
|
|
|
|
has_places365 = True
|
|
|
@ -15,6 +15,16 @@ try:
|
|
|
|
has_inaturalist = True
|
|
|
|
has_inaturalist = True
|
|
|
|
except ImportError:
|
|
|
|
except ImportError:
|
|
|
|
has_inaturalist = False
|
|
|
|
has_inaturalist = False
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
from torchvision.datasets import QMNIST
|
|
|
|
|
|
|
|
has_qmnist = True
|
|
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
|
|
has_qmnist = False
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
from torchvision.datasets import ImageNet
|
|
|
|
|
|
|
|
has_imagenet = True
|
|
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
|
|
has_imagenet = False
|
|
|
|
|
|
|
|
|
|
|
|
from .dataset import IterableImageDataset, ImageDataset
|
|
|
|
from .dataset import IterableImageDataset, ImageDataset
|
|
|
|
|
|
|
|
|
|
|
@ -22,7 +32,6 @@ _TORCH_BASIC_DS = dict(
|
|
|
|
cifar10=CIFAR10,
|
|
|
|
cifar10=CIFAR10,
|
|
|
|
cifar100=CIFAR100,
|
|
|
|
cifar100=CIFAR100,
|
|
|
|
mnist=MNIST,
|
|
|
|
mnist=MNIST,
|
|
|
|
qmist=QMNIST,
|
|
|
|
|
|
|
|
kmnist=KMNIST,
|
|
|
|
kmnist=KMNIST,
|
|
|
|
fashion_mnist=FashionMNIST,
|
|
|
|
fashion_mnist=FashionMNIST,
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -122,7 +131,12 @@ def create_dataset(
|
|
|
|
elif split in _EVAL_SYNONYM:
|
|
|
|
elif split in _EVAL_SYNONYM:
|
|
|
|
split = 'val'
|
|
|
|
split = 'val'
|
|
|
|
ds = Places365(split=split, **torch_kwargs)
|
|
|
|
ds = Places365(split=split, **torch_kwargs)
|
|
|
|
|
|
|
|
elif name == 'qmnist':
|
|
|
|
|
|
|
|
assert has_qmnist, 'Please update to a newer PyTorch and torchvision for QMNIST dataset.'
|
|
|
|
|
|
|
|
use_train = split in _TRAIN_SYNONYM
|
|
|
|
|
|
|
|
ds = QMNIST(train=use_train, **torch_kwargs)
|
|
|
|
elif name == 'imagenet':
|
|
|
|
elif name == 'imagenet':
|
|
|
|
|
|
|
|
assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
|
|
|
|
if split in _EVAL_SYNONYM:
|
|
|
|
if split in _EVAL_SYNONYM:
|
|
|
|
split = 'val'
|
|
|
|
split = 'val'
|
|
|
|
ds = ImageNet(split=split, **torch_kwargs)
|
|
|
|
ds = ImageNet(split=split, **torch_kwargs)
|
|
|
|