You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
149 lines
5.6 KiB
149 lines
5.6 KiB
""" Dataset Factory
|
|
|
|
Hacked together by / Copyright 2021, Ross Wightman
|
|
"""
|
|
import os
|
|
|
|
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
|
|
try:
|
|
from torchvision.datasets import Places365
|
|
has_places365 = True
|
|
except ImportError:
|
|
has_places365 = False
|
|
try:
|
|
from torchvision.datasets import INaturalist
|
|
has_inaturalist = True
|
|
except ImportError:
|
|
has_inaturalist = False
|
|
|
|
from .dataset import IterableImageDataset, ImageDataset
|
|
|
|
# torchvision datasets selectable as `torch/<key>`; create_dataset() builds each of these
# with `train=<bool>` plus the shared root / download / extra kwargs.
_TORCH_BASIC_DS = dict(
    cifar10=CIFAR10,
    cifar100=CIFAR100,
    mnist=MNIST,
    qmnist=QMNIST,
    qmist=QMNIST,  # original (misspelled) key, kept as an alias for existing configs
    kmnist=KMNIST,
    fashion_mnist=FashionMNIST,
)
# Split names that are treated as interchangeable requests for the train / eval split.
_TRAIN_SYNONYM = {'train', 'training'}
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
|
|
|
|
|
|
def _search_split(root, split):
    """Resolve the split-specific sub-folder of ``root``, if one exists.

    Args:
        root: dataset root folder.
        split: split name; anything from the first ``[`` onward is ignored for the
            lookup (e.g. ``train[:10]`` is looked up as ``train``).

    Returns:
        ``root/<split name>`` if that path exists. Otherwise, when the split name is a
        train or eval synonym, the first existing ``root/<synonym>`` (synonyms are
        checked in sorted order). Otherwise ``root`` unchanged.
    """
    split_name = split.split('[')[0]

    # an exact match on the requested split name always wins
    exact = os.path.join(root, split_name)
    if os.path.exists(exact):
        return exact

    if split_name in _TRAIN_SYNONYM:
        synonyms = _TRAIN_SYNONYM
    elif split_name in _EVAL_SYNONYM:
        synonyms = _EVAL_SYNONYM
    else:
        return root

    # A set of str iterates in a hash-dependent order that can differ between runs;
    # sort so the chosen folder is deterministic when several synonym folders exist.
    for synonym in sorted(synonyms):
        candidate = os.path.join(root, synonym)
        if os.path.exists(candidate):
            return candidate
    return root
|
|
|
|
|
|
def create_dataset(
        name,
        root,
        split='validation',
        search_split=True,
        class_map=None,
        load_bytes=False,
        is_training=False,
        download=False,
        batch_size=None,
        repeats=0,
        **kwargs
):
    """ Dataset factory method

    In parentheses after each arg are the type of dataset supported for each arg, one of:
      * folder - default, timm folder (or tar) based ImageDataset
      * torch - torchvision based datasets
      * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
      * WDS - Webdataset
      * all - any of the above

    The dataset type is selected by the prefix of `name`: `torch/`, `tfds/`, `wds/`,
    anything else is handled as a folder dataset.

    Args:
        name: dataset name, empty (or None) is okay for folder based datasets
        root: root folder of dataset (all)
        split: dataset split (all)
        search_split: search for split specific child folder from root so one can specify
            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
        class_map: specify class -> index mapping via text file or dict (folder)
        load_bytes: load data, return images as undecoded bytes (folder)
        download: download dataset if not present and supported (TFDS, torch)
        is_training: create dataset in train mode, this is different from the split.
            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS)
        batch_size: batch size hint for (TFDS)
        repeats: dataset repeats per iteration i.e. epoch (TFDS)
        **kwargs: other args to pass to dataset

    Returns:
        Dataset object

    Raises:
        AssertionError: if a `torch/` dataset name is not recognised, or the requested
            torchvision dataset class could not be imported.
    """
    # treat a missing name like an empty one, i.e. a plain folder dataset
    name = (name or '').lower()
    if name.startswith('torch/'):
        name = name.split('/', 2)[-1]
        torch_kwargs = dict(root=root, download=download, **kwargs)
        if name in _TORCH_BASIC_DS:
            ds_class = _TORCH_BASIC_DS[name]
            use_train = split in _TRAIN_SYNONYM
            ds = ds_class(train=use_train, **torch_kwargs)
        elif name == 'inaturalist' or name == 'inat':
            # raised explicitly (not `assert`) so the check survives `python -O`
            if not has_inaturalist:
                raise AssertionError('Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist')
            # split may be given as `<target_type>/<version>`, where several target types
            # can be joined with `_`; a single target type is passed as a plain string
            target_type = 'full'
            split_split = split.split('/')
            if len(split_split) > 1:
                target_type = split_split[0].split('_')
                if len(target_type) == 1:
                    target_type = target_type[0]
                split = split_split[-1]
            if split in _TRAIN_SYNONYM:
                split = '2021_train'
            elif split in _EVAL_SYNONYM:
                split = '2021_valid'
            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
        elif name == 'places365':
            if not has_places365:
                raise AssertionError('Please update to a newer PyTorch and torchvision for Places365 dataset.')
            if split in _TRAIN_SYNONYM:
                split = 'train-standard'
            elif split in _EVAL_SYNONYM:
                split = 'val'
            ds = Places365(split=split, **torch_kwargs)
        elif name == 'imagenet':
            # normalise both train and eval synonyms, as done for places365 above
            if split in _TRAIN_SYNONYM:
                split = 'train'
            elif split in _EVAL_SYNONYM:
                split = 'val'
            ds = ImageNet(split=split, **torch_kwargs)
        elif name == 'image_folder' or name == 'folder':
            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
            if search_split and os.path.isdir(root):
                # look for split specific sub-folder in root
                root = _search_split(root, split)
            ds = ImageFolder(root, **kwargs)
        else:
            # explicit raise: `assert False` would be stripped under `python -O`
            # and leave `ds` unbound
            raise AssertionError(f"Unknown torchvision dataset {name}")
    elif name.startswith('tfds/'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training,
            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
    elif name.startswith('wds/'):
        # note: unlike the tfds branch, `download` is not forwarded here
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training,
            batch_size=batch_size, repeats=repeats, **kwargs)
    else:
        # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
        if search_split and os.path.isdir(root):
            # look for split specific sub-folder in root
            root = _search_split(root, split)
        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
    return ds
|