diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index 03b03cf5..e86bcc29 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -1,7 +1,11 @@ import os -from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\ - Places365, ImageNet, ImageFolder +from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder +try: + from torchvision.datasets import Places365 + has_places365 = True +except ImportError: + has_places365 = False try: from torchvision.datasets import INaturalist has_inaturalist = True @@ -104,6 +108,7 @@ def create_dataset( split = '2021_valid' ds = INaturalist(version=split, target_type=target_type, **torch_kwargs) elif name == 'places365': + assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.' if split in _TRAIN_SYNONYM: split = 'train-standard' elif split in _EVAL_SYNONYM: