You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
1.0 KiB
30 lines
1.0 KiB
4 years ago
|
import os
|
||
|
|
||
|
from .dataset import IterableImageDataset, ImageDataset
|
||
|
|
||
|
|
||
|
def _search_split(root, split):
|
||
|
# look for sub-folder with name of split in root and use that if it exists
|
||
|
split_name = split.split('[')[0]
|
||
|
try_root = os.path.join(root, split_name)
|
||
|
if os.path.exists(try_root):
|
||
|
return try_root
|
||
|
if split_name == 'validation':
|
||
|
try_root = os.path.join(root, 'val')
|
||
|
if os.path.exists(try_root):
|
||
|
return try_root
|
||
|
return root
|
||
|
|
||
|
|
||
|
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
    """Construct a dataset object selected by the (case-insensitive) ``name``.

    Names starting with ``'tfds'`` (after lowercasing) produce an
    ``IterableImageDataset``, which receives ``split``, ``is_training`` and
    ``batch_size``. Every other name produces an ``ImageDataset``. For that
    branch, ``split`` is used only to pick a sub-path of ``root`` (via
    ``_search_split``) when ``search_split`` is true and ``root`` is a
    directory. ``is_training`` and ``batch_size`` are not forwarded to it.

    Args:
        name: Parser name; lowercased and passed through as ``parser=``.
        root: Dataset root path.
        split: Split identifier.
        search_split: Whether to look for a split sub-path under ``root``
            (non-tfds only).
        is_training: Forwarded to ``IterableImageDataset`` only.
        batch_size: Forwarded to ``IterableImageDataset`` only.
        **kwargs: Extra keyword arguments forwarded to the dataset constructor.

    Returns:
        The constructed dataset instance.
    """
    parser_name = name.lower()

    if not parser_name.startswith('tfds'):
        # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
        dataset_root = root
        if search_split and os.path.isdir(root):
            dataset_root = _search_split(root, split)
        return ImageDataset(dataset_root, parser=parser_name, **kwargs)

    return IterableImageDataset(
        root,
        parser=parser_name,
        split=split,
        is_training=is_training,
        batch_size=batch_size,
        **kwargs,
    )
|