You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/data/dataset_factory.py

144 lines
5.4 KiB

""" Dataset Factory
Hacked together by / Copyright 2021, Ross Wightman
"""
import os
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
try:
from torchvision.datasets import Places365
has_places365 = True
except ImportError:
has_places365 = False
try:
from torchvision.datasets import INaturalist
has_inaturalist = True
except ImportError:
has_inaturalist = False
from .dataset import IterableImageDataset, ImageDataset
# Lookup table for the `torch/<name>` datasets that create_dataset builds via
# `ds_class(train=..., root=..., download=..., **kwargs)`.
_TORCH_BASIC_DS = dict(
    cifar10=CIFAR10,
    cifar100=CIFAR100,
    mnist=MNIST,
    qmnist=QMNIST,
    qmist=QMNIST,  # misspelled legacy key, kept so existing configs / command lines keep working
    kmnist=KMNIST,
    fashion_mnist=FashionMNIST,
)
# Accepted spellings for the train / eval splits. These are sets, used for
# membership tests; their iteration order must not be relied upon.
_TRAIN_SYNONYM = {'train', 'training'}
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}


def _search_split(root, split):
    """ Find the split specific sub-folder of `root`.

    Args:
        root: folder to search in
        split: split name; anything from the first '[' onward is ignored

    Returns:
        `root/<split name>` if that path exists. Otherwise, when the split name is a
        known train / eval synonym, the first existing `root/<synonym>` (synonyms are
        probed in alphabetical order). `root` itself if nothing matches.
    """
    # look for sub-folder with name of split in root and use that if it exists
    split_name = split.split('[')[0]
    try_root = os.path.join(root, split_name)
    if os.path.exists(try_root):
        return try_root

    def _try(syn):
        # sorted() pins the probe order. Iterating the set directly would follow string
        # hash order, which can differ from one interpreter process to the next, so the
        # folder picked when several synonyms exist would not be reproducible.
        for s in sorted(syn):
            try_root = os.path.join(root, s)
            if os.path.exists(try_root):
                return try_root
        return root

    if split_name in _TRAIN_SYNONYM:
        root = _try(_TRAIN_SYNONYM)
    elif split_name in _EVAL_SYNONYM:
        root = _try(_EVAL_SYNONYM)
    return root
def create_dataset(
        name,
        root,
        split='validation',
        search_split=True,
        class_map=None,
        load_bytes=False,
        is_training=False,
        download=False,
        batch_size=None,
        repeats=0,
        **kwargs
):
    """ Dataset factory method

    In parentheses after each arg are the type of dataset supported for each arg, one of:
      * folder - default, timm folder (or tar) based ImageDataset
      * torch - torchvision based datasets
      * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
      * all - any of the above

    Args:
        name: dataset name, empty is okay for folder based datasets. A `torch/` prefix selects
            a torchvision dataset, a `tfds/` prefix selects the TFDS wrapper, anything else is
            passed to ImageDataset as the parser name.
        root: root folder of dataset (all)
        split: dataset split (all). For `torch/inaturalist` it may be written as
            `<target_type>/<version>`, with multiple target types joined by '_'.
        search_split: search for split specific child folder from root so one can specify
            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
        class_map: specify class -> index mapping via text file or dict (folder)
        load_bytes: load data, return images as undecoded bytes (folder)
        download: download dataset if not present and supported (TFDS, torch)
        is_training: create dataset in train mode, this is different from the split.
            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS)
        batch_size: batch size hint for (TFDS)
        repeats: dataset repeats per iteration i.e. epoch (TFDS)
        **kwargs: other args to pass to dataset

    Returns:
        Dataset object

    Raises:
        AssertionError: if a `torch/` dataset name is not recognized, or the requested
            torchvision dataset class could not be imported.
    """
    name = name.lower()
    if name.startswith('torch/'):
        name = name.split('/', 2)[-1]
        torch_kwargs = dict(root=root, download=download, **kwargs)
        if name in _TORCH_BASIC_DS:
            ds_class = _TORCH_BASIC_DS[name]
            use_train = split in _TRAIN_SYNONYM
            ds = ds_class(train=use_train, **torch_kwargs)
        elif name == 'inaturalist' or name == 'inat':
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            if not has_inaturalist:
                raise AssertionError('Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist')
            target_type = 'full'
            split_split = split.split('/')
            if len(split_split) > 1:
                # 'a_b/version' -> target_type ['a', 'b']; a single type stays a plain string
                target_type = split_split[0].split('_')
                if len(target_type) == 1:
                    target_type = target_type[0]
                split = split_split[-1]
            if split in _TRAIN_SYNONYM:
                split = '2021_train'
            elif split in _EVAL_SYNONYM:
                split = '2021_valid'
            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
        elif name == 'places365':
            if not has_places365:
                raise AssertionError('Please update to a newer PyTorch and torchvision for Places365 dataset.')
            if split in _TRAIN_SYNONYM:
                split = 'train-standard'
            elif split in _EVAL_SYNONYM:
                split = 'val'
            ds = Places365(split=split, **torch_kwargs)
        elif name == 'imagenet':
            if split in _EVAL_SYNONYM:
                split = 'val'
            ds = ImageNet(split=split, **torch_kwargs)
        elif name == 'image_folder' or name == 'folder':
            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
            if search_split and os.path.isdir(root):
                # look for split specific sub-folder in root
                root = _search_split(root, split)
            ds = ImageFolder(root, **kwargs)
        else:
            # `assert False` would be stripped under `python -O`, leaving `ds` unbound at the
            # return below; raise the same exception type explicitly instead.
            raise AssertionError(f"Unknown torchvision dataset {name}")
    elif name.startswith('tfds/'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training,
            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
    else:
        # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
        if search_split and os.path.isdir(root):
            # look for split specific sub-folder in root
            root = _search_split(root, split)
        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
    return ds