From de6046e213087e9c282e557d5286a060bee6e594 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 8 Dec 2020 17:03:00 -0800 Subject: [PATCH] Initial commit for dataset / parser reorg to support additional datasets / types --- inference.py | 4 +- timm/data/__init__.py | 2 +- timm/data/dataset.py | 170 +++-------------------- timm/data/parsers/__init__.py | 4 + timm/data/parsers/class_map.py | 15 ++ timm/data/parsers/constants.py | 3 + timm/data/parsers/parser.py | 17 +++ timm/data/parsers/parser_image_folder.py | 69 +++++++++ timm/data/parsers/parser_image_tar.py | 66 +++++++++ timm/data/parsers/parser_in21k_tar.py | 104 ++++++++++++++ train.py | 9 +- validate.py | 7 +- 12 files changed, 309 insertions(+), 161 deletions(-) create mode 100644 timm/data/parsers/__init__.py create mode 100644 timm/data/parsers/class_map.py create mode 100644 timm/data/parsers/constants.py create mode 100644 timm/data/parsers/parser.py create mode 100644 timm/data/parsers/parser_image_folder.py create mode 100644 timm/data/parsers/parser_image_tar.py create mode 100644 timm/data/parsers/parser_in21k_tar.py diff --git a/inference.py b/inference.py index 16d19944..f7ee6d3e 100755 --- a/inference.py +++ b/inference.py @@ -13,7 +13,7 @@ import numpy as np import torch from timm.models import create_model, apply_test_time_pool -from timm.data import Dataset, create_loader, resolve_data_config +from timm.data import ImageDataset, create_loader, resolve_data_config from timm.utils import AverageMeter, setup_default_logging torch.backends.cudnn.benchmark = True @@ -81,7 +81,7 @@ def main(): model = model.cuda() loader = create_loader( - Dataset(args.data), + ImageDataset(args.data), input_size=config['input_size'], batch_size=args.batch_size, use_prefetcher=True, diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 15617859..1dd8ac57 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -1,6 +1,6 @@ from .constants import * from .config import resolve_data_config 
-from .dataset import Dataset, DatasetTar, AugMixDataset +from .dataset import ImageDataset, AugMixDataset from .transforms import * from .loader import create_loader from .transforms_factory import create_transform diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 99d99917..8013c846 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -2,177 +2,49 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import torch.utils.data as data - import os -import re import torch -import tarfile -from PIL import Image - - -IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg'] - - -def natural_key(string_): - """See http://www.codinghorror.com/blog/archives/001018.html""" - return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] - - -def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): - labels = [] - filenames = [] - for root, subdirs, files in os.walk(folder, topdown=False): - rel_path = os.path.relpath(root, folder) if (root != folder) else '' - label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') - for f in files: - base, ext = os.path.splitext(f) - if ext.lower() in types: - filenames.append(os.path.join(root, f)) - labels.append(label) - if class_to_idx is None: - # building class index - unique_labels = set(labels) - sorted_labels = list(sorted(unique_labels, key=natural_key)) - class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} - images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] - if sort: - images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) - return images_and_targets, class_to_idx - - -def load_class_map(filename, root=''): - class_map_path = filename - if not os.path.exists(class_map_path): - class_map_path = os.path.join(root, 
filename) - assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename - class_map_ext = os.path.splitext(filename)[-1].lower() - if class_map_ext == '.txt': - with open(class_map_path) as f: - class_to_idx = {v.strip(): k for k, v in enumerate(f)} - else: - assert False, 'Unsupported class map extension' - return class_to_idx - - -class Dataset(data.Dataset): + +from .parsers import ParserImageFolder, ParserImageTar + + +class ImageDataset(data.Dataset): def __init__( self, - root, + img_root, + parser=None, + class_map='', load_bytes=False, transform=None, - class_map=''): - - class_to_idx = None - if class_map: - class_to_idx = load_class_map(class_map, root) - images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) - if len(images) == 0: - raise RuntimeError(f'Found 0 images in subfolders of {root}. ' - f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}') - self.root = root - self.samples = images - self.imgs = self.samples # torchvision ImageFolder compat - self.class_to_idx = class_to_idx + ): + self.img_root = img_root + if parser is None: + if os.path.isfile(img_root) and os.path.splitext(img_root)[1] == '.tar': + parser = ParserImageTar(img_root, load_bytes=load_bytes, class_map=class_map) + else: + parser = ParserImageFolder(img_root, load_bytes=load_bytes, class_map=class_map) + self.parser = parser self.load_bytes = load_bytes self.transform = transform def __getitem__(self, index): - path, target = self.samples[index] - img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB') + img, target = self.parser[index] if self.transform is not None: img = self.transform(img) if target is None: - target = torch.zeros(1).long() + target = torch.tensor(-1, dtype=torch.long) return img, target def __len__(self): - return len(self.samples) + return len(self.parser) def filename(self, index, basename=False, absolute=False): - filename = self.samples[index][0] - if 
basename: - filename = os.path.basename(filename) - elif not absolute: - filename = os.path.relpath(filename, self.root) - return filename + return self.parser.filename(index, basename, absolute) def filenames(self, basename=False, absolute=False): - fn = lambda x: x - if basename: - fn = os.path.basename - elif not absolute: - fn = lambda x: os.path.relpath(x, self.root) - return [fn(x[0]) for x in self.samples] - - -def _extract_tar_info(tarfile, class_to_idx=None, sort=True): - files = [] - labels = [] - for ti in tarfile.getmembers(): - if not ti.isfile(): - continue - dirname, basename = os.path.split(ti.path) - label = os.path.basename(dirname) - ext = os.path.splitext(basename)[1] - if ext.lower() in IMG_EXTENSIONS: - files.append(ti) - labels.append(label) - if class_to_idx is None: - unique_labels = set(labels) - sorted_labels = list(sorted(unique_labels, key=natural_key)) - class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} - tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] - if sort: - tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) - return tarinfo_and_targets, class_to_idx - - -class DatasetTar(data.Dataset): - - def __init__(self, root, load_bytes=False, transform=None, class_map=''): - - class_to_idx = None - if class_map: - class_to_idx = load_class_map(class_map, root) - assert os.path.isfile(root) - self.root = root - with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later - self.samples, self.class_to_idx = _extract_tar_info(tf, class_to_idx) - self.imgs = self.samples - self.tarfile = None # lazy init in __getitem__ - self.load_bytes = load_bytes - self.transform = transform - - def __getitem__(self, index): - if self.tarfile is None: - self.tarfile = tarfile.open(self.root) - tarinfo, target = self.samples[index] - iob = self.tarfile.extractfile(tarinfo) - img = iob.read() if self.load_bytes else 
Image.open(iob).convert('RGB') - if self.transform is not None: - img = self.transform(img) - if target is None: - target = torch.zeros(1).long() - return img, target - - def __len__(self): - return len(self.samples) - - def filename(self, index, basename=False): - filename = self.samples[index][0].name - if basename: - filename = os.path.basename(filename) - return filename - - def filenames(self, basename=False): - fn = os.path.basename if basename else lambda x: x - return [fn(x[0].name) for x in self.samples] + return self.parser.filenames(basename, absolute) class AugMixDataset(torch.utils.data.Dataset): diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py new file mode 100644 index 00000000..c502eec8 --- /dev/null +++ b/timm/data/parsers/__init__.py @@ -0,0 +1,4 @@ +from .parser import Parser +from .parser_image_folder import ParserImageFolder +from .parser_image_tar import ParserImageTar +from .parser_in21k_tar import ParserIn21kTar \ No newline at end of file diff --git a/timm/data/parsers/class_map.py b/timm/data/parsers/class_map.py new file mode 100644 index 00000000..f5fa7e2a --- /dev/null +++ b/timm/data/parsers/class_map.py @@ -0,0 +1,15 @@ +import os + +def load_class_map(filename, root=''): + class_map_path = filename + if not os.path.exists(class_map_path): + class_map_path = os.path.join(root, filename) + assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename + class_map_ext = os.path.splitext(filename)[-1].lower() + if class_map_ext == '.txt': + with open(class_map_path) as f: + class_to_idx = {v.strip(): k for k, v in enumerate(f)} + else: + assert False, 'Unsupported class map extension' + return class_to_idx + diff --git a/timm/data/parsers/constants.py b/timm/data/parsers/constants.py new file mode 100644 index 00000000..6e3be34b --- /dev/null +++ b/timm/data/parsers/constants.py @@ -0,0 +1,3 @@ +IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') + + diff --git a/timm/data/parsers/parser.py 
b/timm/data/parsers/parser.py new file mode 100644 index 00000000..76ab6d18 --- /dev/null +++ b/timm/data/parsers/parser.py @@ -0,0 +1,17 @@ +from abc import abstractmethod + + +class Parser: + def __init__(self): + pass + + @abstractmethod + def _filename(self, index, basename=False, absolute=False): + pass + + def filename(self, index, basename=False, absolute=False): + return self._filename(index, basename=basename, absolute=absolute) + + def filenames(self, basename=False, absolute=False): + return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] + diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/parsers/parser_image_folder.py new file mode 100644 index 00000000..8a61007f --- /dev/null +++ b/timm/data/parsers/parser_image_folder.py @@ -0,0 +1,69 @@ +import os +import io +import torch + +from PIL import Image +from timm.utils.misc import natural_key + +from .parser import Parser +from .class_map import load_class_map +from .constants import IMG_EXTENSIONS + + +def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): + labels = [] + filenames = [] + for root, subdirs, files in os.walk(folder, topdown=False): + rel_path = os.path.relpath(root, folder) if (root != folder) else '' + label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') + for f in files: + base, ext = os.path.splitext(f) + if ext.lower() in types: + filenames.append(os.path.join(root, f)) + labels.append(label) + if class_to_idx is None: + # building class index + unique_labels = set(labels) + sorted_labels = list(sorted(unique_labels, key=natural_key)) + class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] + if sort: + images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) + return images_and_targets, class_to_idx 
+ + +class ParserImageFolder(Parser): + + def __init__( + self, + root, + load_bytes=False, + class_map=''): + super().__init__() + + self.root = root + self.load_bytes = load_bytes + + class_to_idx = None + if class_map: + class_to_idx = load_class_map(class_map, root) + self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) + if len(self.samples) == 0: + raise RuntimeError(f'Found 0 images in subfolders of {root}. ' + f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}') + + def __getitem__(self, index): + path, target = self.samples[index] + img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB') + return img, target + + def __len__(self): + return len(self.samples) + + def _filename(self, index, basename=False, absolute=False): + filename = self.samples[index][0] + if basename: + filename = os.path.basename(filename) + elif not absolute: + filename = os.path.relpath(filename, self.root) + return filename diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/parsers/parser_image_tar.py new file mode 100644 index 00000000..504e71e8 --- /dev/null +++ b/timm/data/parsers/parser_image_tar.py @@ -0,0 +1,66 @@ +import os +import io +import torch +import tarfile + +from .parser import Parser +from .class_map import load_class_map +from .constants import IMG_EXTENSIONS +from PIL import Image +from timm.utils.misc import natural_key + + +def extract_tar_info(tarfile, class_to_idx=None, sort=True): + files = [] + labels = [] + for ti in tarfile.getmembers(): + if not ti.isfile(): + continue + dirname, basename = os.path.split(ti.path) + label = os.path.basename(dirname) + ext = os.path.splitext(basename)[1] + if ext.lower() in IMG_EXTENSIONS: + files.append(ti) + labels.append(label) + if class_to_idx is None: + unique_labels = set(labels) + sorted_labels = list(sorted(unique_labels, key=natural_key)) + class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + tarinfo_and_targets = 
[(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] + if sort: + tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) + return tarinfo_and_targets, class_to_idx + + +class ParserImageTar(Parser): + + def __init__(self, root, load_bytes=False, class_map=''): + super().__init__() + + class_to_idx = None + if class_map: + class_to_idx = load_class_map(class_map, root) + assert os.path.isfile(root) + self.root = root + with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later + self.samples, self.class_to_idx = extract_tar_info(tf, class_to_idx) + self.imgs = self.samples + self.tarfile = None # lazy init in __getitem__ + self.load_bytes = load_bytes + + def __getitem__(self, index): + if self.tarfile is None: + self.tarfile = tarfile.open(self.root) + tarinfo, target = self.samples[index] + iob = self.tarfile.extractfile(tarinfo) + img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB') + return img, target + + def __len__(self): + return len(self.samples) + + def _filename(self, index, basename=False, absolute=False): + filename = self.samples[index][0].name + if basename: + filename = os.path.basename(filename) + return filename diff --git a/timm/data/parsers/parser_in21k_tar.py b/timm/data/parsers/parser_in21k_tar.py new file mode 100644 index 00000000..da7e9d26 --- /dev/null +++ b/timm/data/parsers/parser_in21k_tar.py @@ -0,0 +1,104 @@ +import os +import io +import re +import torch +import tarfile +import pickle +from glob import glob +import numpy as np + +import torch.utils.data as data + +from timm.utils.misc import natural_key + +from .constants import IMG_EXTENSIONS + + +def load_class_map(filename, root=''): + class_map_path = filename + if not os.path.exists(class_map_path): + class_map_path = os.path.join(root, filename) + assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename + class_map_ext = 
os.path.splitext(filename)[-1].lower() + if class_map_ext == '.txt': + with open(class_map_path) as f: + class_to_idx = {v.strip(): k for k, v in enumerate(f)} + else: + assert False, 'Unsupported class map extension' + return class_to_idx + + +class ParserIn21kTar(data.Dataset): + + CACHE_FILENAME = 'class_info.pickle' + + def __init__(self, root, class_map=''): + + class_to_idx = None + if class_map: + class_to_idx = load_class_map(class_map, root) + assert os.path.isdir(root) + self.root = root + tar_filenames = glob(os.path.join(self.root, '*.tar'), recursive=True) + assert len(tar_filenames) + num_tars = len(tar_filenames) + + if os.path.exists(self.CACHE_FILENAME): + with open(self.CACHE_FILENAME, 'rb') as pf: + class_info = pickle.load(pf) + else: + class_info = {} + for fi, fn in enumerate(tar_filenames): + if fi % 1000 == 0: + print(f'DEBUG: tar {fi}/{num_tars}') + # cannot keep this open across processes, reopen later + name = os.path.splitext(os.path.basename(fn))[0] + img_tarinfos = [] + with tarfile.open(fn) as tf: + img_tarinfos.extend(tf.getmembers()) + class_info[name] = dict(img_tarinfos=img_tarinfos) + print(f'DEBUG: {len(img_tarinfos)} images for synset {name}') + class_info = {k: v for k, v in sorted(class_info.items())} + + with open('class_info.pickle', 'wb') as pf: + pickle.dump(class_info, pf, protocol=pickle.HIGHEST_PROTOCOL) + + if class_to_idx is not None: + out_dict = {} + for k, v in class_info.items(): + if k in class_to_idx: + class_idx = class_to_idx[k] + v['class_idx'] = class_idx + out_dict[k] = v + class_info = {k: v for k, v in sorted(out_dict.items(), key=lambda x: x[1]['class_idx'])} + else: + for i, (k, v) in enumerate(class_info.items()): + v['class_idx'] = i + + self.img_infos = [] + self.targets = [] + self.tarnames = [] + for k, v in class_info.items(): + num_samples = len(v['img_tarinfos']) + self.img_infos.extend(v['img_tarinfos']) + self.targets.extend([v['class_idx']] * num_samples) + self.tarnames.extend([k] * 
num_samples) + self.targets = np.array(self.targets) # separate, uniform np array are more memory efficient + self.tarnames = np.array(self.tarnames) + + self.tarfiles = {} # to open lazily + del class_info + + def __len__(self): + return len(self.img_infos) + + def __getitem__(self, idx): + img_tarinfo = self.img_infos[idx] + name = self.tarnames[idx] + tf = self.tarfiles.setdefault(name, tarfile.open(os.path.join(self.root, name + '.tar'))) + img_bytes = tf.extractfile(img_tarinfo) + if self.targets is not None: + target = self.targets[idx] + else: + target = None + return img_bytes, target diff --git a/train.py b/train.py index 7a93a1b6..ca406655 100755 --- a/train.py +++ b/train.py @@ -28,7 +28,7 @@ import torch.nn as nn import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP -from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset +from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy @@ -275,7 +275,7 @@ def _parse_args(): def main(): - setup_default_logging() + setup_default_logging(log_path='./train.log') args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher @@ -330,6 +330,7 @@ def main(): scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint) + print(model) if args.local_rank == 0: _logger.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) @@ -439,7 +440,7 @@ def main(): if not os.path.exists(train_dir): _logger.error('Training folder does not exist at: {}'.format(train_dir)) exit(1) - dataset_train = Dataset(train_dir) + dataset_train = ImageDataset(train_dir) eval_dir = os.path.join(args.data, 'val') if not os.path.isdir(eval_dir): @@ 
-447,7 +448,7 @@ def main(): if not os.path.isdir(eval_dir): _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) exit(1) - dataset_eval = Dataset(eval_dir) + dataset_eval = ImageDataset(eval_dir) # setup mixup / cutmix collate_fn = None diff --git a/validate.py b/validate.py index 5a0d388c..645dfd1d 100755 --- a/validate.py +++ b/validate.py @@ -20,7 +20,7 @@ from collections import OrderedDict from contextlib import suppress from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models -from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet +from timm.data import ImageDataset, create_loader, resolve_data_config, RealLabelsImagenet from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy has_apex = False @@ -157,10 +157,7 @@ def validate(args): criterion = nn.CrossEntropyLoss().cuda() - if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data): - dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map) - else: - dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map) + dataset = ImageDataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map) if args.valid_labels: with open(args.valid_labels, 'r') as f: