|
|
@ -1,7 +1,11 @@
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
|
|
|
|
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
|
|
|
|
Places365, ImageNet, ImageFolder
|
|
|
|
try:
|
|
|
|
|
|
|
|
from torchvision.datasets import Places365
|
|
|
|
|
|
|
|
has_places365 = True
|
|
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
|
|
has_places365 = False
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
from torchvision.datasets import INaturalist
|
|
|
|
from torchvision.datasets import INaturalist
|
|
|
|
has_inaturalist = True
|
|
|
|
has_inaturalist = True
|
|
|
@ -104,6 +108,7 @@ def create_dataset(
|
|
|
|
split = '2021_valid'
|
|
|
|
split = '2021_valid'
|
|
|
|
ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
|
|
|
|
ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
|
|
|
|
elif name == 'places365':
|
|
|
|
elif name == 'places365':
|
|
|
|
|
|
|
|
assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
|
|
|
|
if split in _TRAIN_SYNONYM:
|
|
|
|
if split in _TRAIN_SYNONYM:
|
|
|
|
split = 'train-standard'
|
|
|
|
split = 'train-standard'
|
|
|
|
elif split in _EVAL_SYNONYM:
|
|
|
|
elif split in _EVAL_SYNONYM:
|
|
|
|