Initial commit for dataset / parser reorg to support additional datasets / types

Branch: pull/323/head
Author: Ross Wightman, 4 years ago
Parent: 392595c7eb
Commit: de6046e213

inference.py

@@ -13,7 +13,7 @@ import numpy as np
 import torch
 
 from timm.models import create_model, apply_test_time_pool
-from timm.data import Dataset, create_loader, resolve_data_config
+from timm.data import ImageDataset, create_loader, resolve_data_config
 from timm.utils import AverageMeter, setup_default_logging
 
 torch.backends.cudnn.benchmark = True
@@ -81,7 +81,7 @@ def main():
     model = model.cuda()
 
     loader = create_loader(
-        Dataset(args.data),
+        ImageDataset(args.data),
         input_size=config['input_size'],
         batch_size=args.batch_size,
         use_prefetcher=True,

timm/data/__init__.py

@@ -1,6 +1,6 @@
 from .constants import *
 from .config import resolve_data_config
-from .dataset import Dataset, DatasetTar, AugMixDataset
+from .dataset import ImageDataset, AugMixDataset
 from .transforms import *
 from .loader import create_loader
 from .transforms_factory import create_transform

timm/data/dataset.py

@@ -2,177 +2,49 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import torch.utils.data as data
 import os
-import re
 import torch
-import tarfile
-from PIL import Image
+
+from .parsers import ParserImageFolder, ParserImageTar
 
 
-IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
-
-
-def natural_key(string_):
-    """See http://www.codinghorror.com/blog/archives/001018.html"""
-    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
-
-
-def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
-    labels = []
-    filenames = []
-    for root, subdirs, files in os.walk(folder, topdown=False):
-        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
-        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
-        for f in files:
-            base, ext = os.path.splitext(f)
-            if ext.lower() in types:
-                filenames.append(os.path.join(root, f))
-                labels.append(label)
-    if class_to_idx is None:
-        # building class index
-        unique_labels = set(labels)
-        sorted_labels = list(sorted(unique_labels, key=natural_key))
-        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
-    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
-    if sort:
-        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
-    return images_and_targets, class_to_idx
-
-
-def load_class_map(filename, root=''):
-    class_map_path = filename
-    if not os.path.exists(class_map_path):
-        class_map_path = os.path.join(root, filename)
-        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
-    class_map_ext = os.path.splitext(filename)[-1].lower()
-    if class_map_ext == '.txt':
-        with open(class_map_path) as f:
-            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
-    else:
-        assert False, 'Unsupported class map extension'
-    return class_to_idx
-
-
-class Dataset(data.Dataset):
+class ImageDataset(data.Dataset):
 
     def __init__(
             self,
-            root,
+            img_root,
+            parser=None,
+            class_map='',
             load_bytes=False,
             transform=None,
-            class_map=''):
-
-        class_to_idx = None
-        if class_map:
-            class_to_idx = load_class_map(class_map, root)
-        images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
-        if len(images) == 0:
-            raise RuntimeError(f'Found 0 images in subfolders of {root}. '
-                               f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
-        self.root = root
-        self.samples = images
-        self.imgs = self.samples  # torchvision ImageFolder compat
-        self.class_to_idx = class_to_idx
+    ):
+        self.img_root = img_root
+        if parser is None:
+            if os.path.isfile(img_root) and os.path.splitext(img_root)[1] == '.tar':
+                parser = ParserImageTar(img_root, load_bytes=load_bytes, class_map=class_map)
+            else:
+                parser = ParserImageFolder(img_root, load_bytes=load_bytes, class_map=class_map)
+        self.parser = parser
         self.load_bytes = load_bytes
         self.transform = transform
 
     def __getitem__(self, index):
-        path, target = self.samples[index]
-        img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
+        img, target = self.parser[index]
         if self.transform is not None:
             img = self.transform(img)
         if target is None:
-            target = torch.zeros(1).long()
+            target = torch.tensor(-1, dtype=torch.long)
         return img, target
 
     def __len__(self):
-        return len(self.samples)
+        return len(self.parser)
 
     def filename(self, index, basename=False, absolute=False):
-        filename = self.samples[index][0]
-        if basename:
-            filename = os.path.basename(filename)
-        elif not absolute:
-            filename = os.path.relpath(filename, self.root)
-        return filename
+        return self.parser.filename(index, basename, absolute)
 
     def filenames(self, basename=False, absolute=False):
-        fn = lambda x: x
-        if basename:
-            fn = os.path.basename
-        elif not absolute:
-            fn = lambda x: os.path.relpath(x, self.root)
-        return [fn(x[0]) for x in self.samples]
-
-
-def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
-    files = []
-    labels = []
-    for ti in tarfile.getmembers():
-        if not ti.isfile():
-            continue
-        dirname, basename = os.path.split(ti.path)
-        label = os.path.basename(dirname)
-        ext = os.path.splitext(basename)[1]
-        if ext.lower() in IMG_EXTENSIONS:
-            files.append(ti)
-            labels.append(label)
-    if class_to_idx is None:
-        unique_labels = set(labels)
-        sorted_labels = list(sorted(unique_labels, key=natural_key))
-        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
-    tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
-    if sort:
-        tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
-    return tarinfo_and_targets, class_to_idx
-
-
-class DatasetTar(data.Dataset):
-
-    def __init__(self, root, load_bytes=False, transform=None, class_map=''):
-        class_to_idx = None
-        if class_map:
-            class_to_idx = load_class_map(class_map, root)
-        assert os.path.isfile(root)
-        self.root = root
-        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
-            self.samples, self.class_to_idx = _extract_tar_info(tf, class_to_idx)
-        self.imgs = self.samples
-        self.tarfile = None  # lazy init in __getitem__
-        self.load_bytes = load_bytes
-        self.transform = transform
-
-    def __getitem__(self, index):
-        if self.tarfile is None:
-            self.tarfile = tarfile.open(self.root)
-        tarinfo, target = self.samples[index]
-        iob = self.tarfile.extractfile(tarinfo)
-        img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
-        if self.transform is not None:
-            img = self.transform(img)
-        if target is None:
-            target = torch.zeros(1).long()
-        return img, target
-
-    def __len__(self):
-        return len(self.samples)
-
-    def filename(self, index, basename=False):
-        filename = self.samples[index][0].name
-        if basename:
-            filename = os.path.basename(filename)
-        return filename
-
-    def filenames(self, basename=False):
-        fn = os.path.basename if basename else lambda x: x
-        return [fn(x[0].name) for x in self.samples]
+        return self.parser.filenames(basename, absolute)
 
 
 class AugMixDataset(torch.utils.data.Dataset):
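
A quick usage sketch (not part of the commit; paths and transform are hypothetical): the new ImageDataset resolves its parser from the path, so call sites no longer choose between Dataset and DatasetTar.

from torchvision import transforms
from timm.data import ImageDataset

tfm = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])

# a directory of class subfolders selects ParserImageFolder
ds_folder = ImageDataset('/data/imagenet/val', transform=tfm)
# a single .tar of class subfolders selects ParserImageTar
ds_tar = ImageDataset('/data/imagenet/val.tar', transform=tfm)

img, target = ds_folder[0]  # decoded PIL image run through the transform, plus int class index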

timm/data/parsers/__init__.py

@@ -0,0 +1,4 @@
+from .parser import Parser
+from .parser_image_folder import ParserImageFolder
+from .parser_image_tar import ParserImageTar
+from .parser_in21k_tar import ParserIn21kTar

timm/data/parsers/class_map.py

@@ -0,0 +1,15 @@
+import os
+
+
+def load_class_map(filename, root=''):
+    class_map_path = filename
+    if not os.path.exists(class_map_path):
+        class_map_path = os.path.join(root, filename)
+        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
+    class_map_ext = os.path.splitext(filename)[-1].lower()
+    if class_map_ext == '.txt':
+        with open(class_map_path) as f:
+            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
+    else:
+        assert False, 'Unsupported class map extension'
+    return class_to_idx
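
The class map is a plain text file with one class name per line; line order defines the class index. A minimal sketch of how load_class_map would be used (file name and contents hypothetical):

from timm.data.parsers.class_map import load_class_map

# classmap.txt containing the two lines 'n01440764' and 'n01443537' yields:
class_to_idx = load_class_map('classmap.txt', root='/data/imagenet')
# {'n01440764': 0, 'n01443537': 1}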

timm/data/parsers/constants.py

@@ -0,0 +1,3 @@
+IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

timm/data/parsers/parser.py

@@ -0,0 +1,17 @@
+from abc import abstractmethod
+
+
+class Parser:
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def _filename(self, index, basename=False, absolute=False):
+        pass
+
+    def filename(self, index, basename=False, absolute=False):
+        return self._filename(index, basename=basename, absolute=absolute)
+
+    def filenames(self, basename=False, absolute=False):
+        return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
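
Parser only fixes the public filename/filenames API; concrete parsers are expected to supply __getitem__, __len__ and _filename. A hypothetical minimal subclass, mirroring the folder parser's semantics:

import os
from timm.data.parsers import Parser

class ListParser(Parser):
    # toy parser over a pre-built list of (path, target) pairs
    def __init__(self, samples, root=''):
        super().__init__()
        self.root = root
        self.samples = samples

    def __getitem__(self, index):
        path, target = self.samples[index]
        return open(path, 'rb').read(), target  # raw bytes, like load_bytes=True

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        filename = self.samples[index][0]
        if basename:
            filename = os.path.basename(filename)
        elif not absolute:
            filename = os.path.relpath(filename, self.root)
        return filename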

timm/data/parsers/parser_image_folder.py

@@ -0,0 +1,69 @@
+import os
+import io
+import torch
+
+from PIL import Image
+
+from timm.utils.misc import natural_key
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+
+
+def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
+    labels = []
+    filenames = []
+    for root, subdirs, files in os.walk(folder, topdown=False):
+        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
+        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
+        for f in files:
+            base, ext = os.path.splitext(f)
+            if ext.lower() in types:
+                filenames.append(os.path.join(root, f))
+                labels.append(label)
+    if class_to_idx is None:
+        # building class index
+        unique_labels = set(labels)
+        sorted_labels = list(sorted(unique_labels, key=natural_key))
+        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
+    if sort:
+        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
+    return images_and_targets, class_to_idx
+
+
+class ParserImageFolder(Parser):
+
+    def __init__(
+            self,
+            root,
+            load_bytes=False,
+            class_map=''):
+        super().__init__()
+
+        self.root = root
+        self.load_bytes = load_bytes
+
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
+        self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
+        if len(self.samples) == 0:
+            raise RuntimeError(f'Found 0 images in subfolders of {root}. '
+                               f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
+
+    def __getitem__(self, index):
+        path, target = self.samples[index]
+        img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
+        return img, target
+
+    def __len__(self):
+        return len(self.samples)
+
+    def _filename(self, index, basename=False, absolute=False):
+        filename = self.samples[index][0]
+        if basename:
+            filename = os.path.basename(filename)
+        elif not absolute:
+            filename = os.path.relpath(filename, self.root)
+        return filename
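
find_images_and_targets walks the tree bottom-up, labels each image with its leaf folder name, and natural-sorts labels so 'class2' indexes before 'class10'. Expected shapes for a hypothetical layout:

from timm.data.parsers.parser_image_folder import find_images_and_targets

# given /data/toy/cat/0.jpg and /data/toy/dog/1.jpg
samples, class_to_idx = find_images_and_targets('/data/toy')
# class_to_idx == {'cat': 0, 'dog': 1}
# samples == [('/data/toy/cat/0.jpg', 0), ('/data/toy/dog/1.jpg', 1)]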

timm/data/parsers/parser_image_tar.py

@@ -0,0 +1,66 @@
+import os
+import io
+import torch
+import tarfile
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+
+from PIL import Image
+from timm.utils.misc import natural_key
+
+
+def extract_tar_info(tarfile, class_to_idx=None, sort=True):
+    files = []
+    labels = []
+    for ti in tarfile.getmembers():
+        if not ti.isfile():
+            continue
+        dirname, basename = os.path.split(ti.path)
+        label = os.path.basename(dirname)
+        ext = os.path.splitext(basename)[1]
+        if ext.lower() in IMG_EXTENSIONS:
+            files.append(ti)
+            labels.append(label)
+    if class_to_idx is None:
+        unique_labels = set(labels)
+        sorted_labels = list(sorted(unique_labels, key=natural_key))
+        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+    tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
+    if sort:
+        tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
+    return tarinfo_and_targets, class_to_idx
+
+
+class ParserImageTar(Parser):
+
+    def __init__(self, root, load_bytes=False, class_map=''):
+        super().__init__()
+
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
+        assert os.path.isfile(root)
+        self.root = root
+
+        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
+            self.samples, self.class_to_idx = extract_tar_info(tf, class_to_idx)
+        self.imgs = self.samples
+        self.tarfile = None  # lazy init in __getitem__
+        self.load_bytes = load_bytes
+
+    def __getitem__(self, index):
+        if self.tarfile is None:
+            self.tarfile = tarfile.open(self.root)
+        tarinfo, target = self.samples[index]
+        iob = self.tarfile.extractfile(tarinfo)
+        img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
+        return img, target
+
+    def __len__(self):
+        return len(self.samples)
+
+    def _filename(self, index, basename=False, absolute=False):
+        filename = self.samples[index][0].name
+        if basename:
+            filename = os.path.basename(filename)
+        return filename
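
ParserImageTar builds the member index eagerly but opens the actual TarFile handle lazily in __getitem__, so after a DataLoader forks its workers each process reopens its own handle (TarFile objects cannot be shared safely across processes). A usage sketch with hypothetical paths:

from torch.utils.data import DataLoader
from torchvision import transforms
from timm.data import ImageDataset
from timm.data.parsers import ParserImageTar

parser = ParserImageTar('/data/imagenet/val.tar')
tfm = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
ds = ImageDataset('/data/imagenet/val.tar', parser=parser, transform=tfm)
loader = DataLoader(ds, batch_size=32, num_workers=4)  # each worker lazily opens its own TarFile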

timm/data/parsers/parser_in21k_tar.py

@@ -0,0 +1,104 @@
+import os
+import io
+import re
+import torch
+import tarfile
+import pickle
+from glob import glob
+import numpy as np
+
+import torch.utils.data as data
+
+from timm.utils.misc import natural_key
+
+from .constants import IMG_EXTENSIONS
+
+
+def load_class_map(filename, root=''):
+    class_map_path = filename
+    if not os.path.exists(class_map_path):
+        class_map_path = os.path.join(root, filename)
+        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
+    class_map_ext = os.path.splitext(filename)[-1].lower()
+    if class_map_ext == '.txt':
+        with open(class_map_path) as f:
+            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
+    else:
+        assert False, 'Unsupported class map extension'
+    return class_to_idx
+
+
+class ParserIn21kTar(data.Dataset):
+
+    CACHE_FILENAME = 'class_info.pickle'
+
+    def __init__(self, root, class_map=''):
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
+        assert os.path.isdir(root)
+        self.root = root
+        tar_filenames = glob(os.path.join(self.root, '*.tar'), recursive=True)
+        assert len(tar_filenames)
+        num_tars = len(tar_filenames)
+
+        if os.path.exists(self.CACHE_FILENAME):
+            with open(self.CACHE_FILENAME, 'rb') as pf:
+                class_info = pickle.load(pf)
+        else:
+            class_info = {}
+            for fi, fn in enumerate(tar_filenames):
+                if fi % 1000 == 0:
+                    print(f'DEBUG: tar {fi}/{num_tars}')
+                # cannot keep this open across processes, reopen later
+                name = os.path.splitext(os.path.basename(fn))[0]
+                img_tarinfos = []
+                with tarfile.open(fn) as tf:
+                    img_tarinfos.extend(tf.getmembers())
+                class_info[name] = dict(img_tarinfos=img_tarinfos)
+                print(f'DEBUG: {len(img_tarinfos)} images for synset {name}')
+            class_info = {k: v for k, v in sorted(class_info.items())}
+
+            with open(self.CACHE_FILENAME, 'wb') as pf:
+                pickle.dump(class_info, pf, protocol=pickle.HIGHEST_PROTOCOL)
+
+        if class_to_idx is not None:
+            out_dict = {}
+            for k, v in class_info.items():
+                if k in class_to_idx:
+                    class_idx = class_to_idx[k]
+                    v['class_idx'] = class_idx
+                    out_dict[k] = v
+            class_info = {k: v for k, v in sorted(out_dict.items(), key=lambda x: x[1]['class_idx'])}
+        else:
+            for i, (k, v) in enumerate(class_info.items()):
+                v['class_idx'] = i
+
+        self.img_infos = []
+        self.targets = []
+        self.tarnames = []
+        for k, v in class_info.items():
+            num_samples = len(v['img_tarinfos'])
+            self.img_infos.extend(v['img_tarinfos'])
+            self.targets.extend([v['class_idx']] * num_samples)
+            self.tarnames.extend([k] * num_samples)
+        self.targets = np.array(self.targets)  # separate, uniform np arrays are more memory efficient
+        self.tarnames = np.array(self.tarnames)
+
+        self.tarfiles = {}  # to open lazily
+        del class_info
+
+    def __len__(self):
+        return len(self.img_infos)
+
+    def __getitem__(self, idx):
+        img_tarinfo = self.img_infos[idx]
+        name = self.tarnames[idx]
+        tf = self.tarfiles.setdefault(name, tarfile.open(os.path.join(self.root, name + '.tar')))
+        img_bytes = tf.extractfile(img_tarinfo)
+        if len(self.targets):  # note: plain `if self.targets:` is ambiguous for a numpy array
+            target = self.targets[idx]
+        else:
+            target = None
+        return img_bytes, target
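
ParserIn21kTar expects a directory of per-synset tars and caches the combined member index in class_info.pickle under the current working directory, so the expensive getmembers() scan over the ~21k tar files only runs once. It returns the extracted file object rather than a decoded image, so a consumer decodes it itself, e.g. (hypothetical path):

from PIL import Image
from timm.data.parsers import ParserIn21kTar

parser = ParserIn21kTar('/data/imagenet21k')  # directory containing the n*.tar synset files
fobj, target = parser[0]
img = Image.open(fobj).convert('RGB')  # decode the file-like object returned by the parser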

train.py

@@ -28,7 +28,7 @@ import torch.nn as nn
 import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
-from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
@@ -275,7 +275,7 @@ def _parse_args():
 
 def main():
-    setup_default_logging()
+    setup_default_logging(log_path='./train.log')
     args, args_text = _parse_args()
 
     args.prefetcher = not args.no_prefetcher
@@ -330,6 +330,7 @@ def main():
         scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint)
+    print(model)
 
     if args.local_rank == 0:
         _logger.info('Model %s created, param count: %d' %
                      (args.model, sum([m.numel() for m in model.parameters()])))
@@ -439,7 +440,7 @@ def main():
     if not os.path.exists(train_dir):
         _logger.error('Training folder does not exist at: {}'.format(train_dir))
         exit(1)
-    dataset_train = Dataset(train_dir)
+    dataset_train = ImageDataset(train_dir)
 
     eval_dir = os.path.join(args.data, 'val')
     if not os.path.isdir(eval_dir):
@@ -447,7 +448,7 @@ def main():
         if not os.path.isdir(eval_dir):
             _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
             exit(1)
-    dataset_eval = Dataset(eval_dir)
+    dataset_eval = ImageDataset(eval_dir)
 
     # setup mixup / cutmix
     collate_fn = None

validate.py

@@ -20,7 +20,7 @@ from collections import OrderedDict
 from contextlib import suppress
 
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
+from timm.data import ImageDataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
 
 has_apex = False
@@ -157,10 +157,7 @@ def validate(args):
     criterion = nn.CrossEntropyLoss().cuda()
 
-    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
-        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
-    else:
-        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
+    dataset = ImageDataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
 
     if args.valid_labels:
         with open(args.valid_labels, 'r') as f:
