|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
import csv
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
from .dataset import IterableImageDataset, ImageDataset
|
|
|
|
|
from .dataset import IterableImageDataset, ImageDataset, COAIImageClassDataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _search_split(root, split):
|
|
|
|
@ -21,6 +22,10 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
|
|
|
|
|
if name.startswith('tfds'):
|
|
|
|
|
ds = IterableImageDataset(
|
|
|
|
|
root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
|
|
|
|
|
elif name.startswith('coaiclass'):
|
|
|
|
|
# Get Dict from csv(current implementation)/mongodb(needs to be added)
|
|
|
|
|
dict = _get_dict_from_csv(root)
|
|
|
|
|
ds = COAIImageClassDataset(dict=dict)
|
|
|
|
|
else:
|
|
|
|
|
# FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
|
|
|
|
|
kwargs.pop('repeats', 0) # FIXME currently only Iterable dataset support the repeat multiplier
|
|
|
|
@ -28,3 +33,15 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
|
|
|
|
|
root = _search_split(root, split)
|
|
|
|
|
ds = ImageDataset(root, parser=name, **kwargs)
|
|
|
|
|
return ds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert(lst):
|
|
|
|
|
res_dct = {lst[i][0]: lst[i][1] for i in range(len(lst))}
|
|
|
|
|
return res_dct
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_dict_from_csv(data_folder):
    """Read ``train.csv`` from ``data_folder`` and return it as a dict.

    Args:
        data_folder: Directory that contains ``train.csv`` (str or path-like).

    Returns:
        The result of ``_convert`` applied to the parsed rows, i.e. a dict
        keyed by each row's first column with the second column as value.

    Raises:
        OSError: If ``train.csv`` cannot be opened.
    """
    # Join with os.path so path-like roots and roots with a trailing separator
    # are handled, instead of concatenating strings by hand.
    csv_path = os.path.join(data_folder, 'train.csv')
    # The csv module requires files to be opened with newline='' so that
    # newlines embedded in quoted fields are parsed correctly.
    with open(csv_path, 'r', newline='') as f:
        rows = list(csv.reader(f))
    return _convert(rows)
|