From 855d6cc2171f309cd819e8ec1cafd117685b1be8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 15 Jan 2021 17:26:20 -0800 Subject: [PATCH] More dataset work including factories and a tensorflow datasets (TFDS) wrapper * Add parser/dataset factory methods for more flexible dataset & parser creation * Add dataset parser that wraps TFDS image classification datasets * Tweak num_classes handling bug for 21k models * Add initial deit models so they can be benchmarked in next csv results runs --- timm/data/__init__.py | 14 +- timm/data/dataset.py | 52 ++++++- timm/data/dataset_factory.py | 29 ++++ timm/data/loader.py | 17 ++- timm/data/parsers/__init__.py | 5 +- timm/data/parsers/parser_factory.py | 29 ++++ timm/data/parsers/parser_tfds.py | 201 ++++++++++++++++++++++++++++ timm/models/helpers.py | 6 +- timm/models/resnetv2.py | 12 +- timm/models/vision_transformer.py | 84 ++++++++++-- train.py | 35 +++-- validate.py | 12 +- 12 files changed, 431 insertions(+), 65 deletions(-) create mode 100644 timm/data/dataset_factory.py create mode 100644 timm/data/parsers/parser_factory.py create mode 100644 timm/data/parsers/parser_tfds.py diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 1dd8ac57..7d3cb2b4 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -1,10 +1,12 @@ -from .constants import * +from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ + rand_augment_transform, auto_augment_transform from .config import resolve_data_config -from .dataset import ImageDataset, AugMixDataset -from .transforms import * +from .constants import * +from .dataset import ImageDataset, IterableImageDataset, AugMixDataset +from .dataset_factory import create_dataset from .loader import create_loader -from .transforms_factory import create_transform from .mixup import Mixup, FastCollateMixup -from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ - rand_augment_transform, auto_augment_transform +from .parsers import create_parser from .real_labels import RealLabelsImagenet +from .transforms import * +from .transforms_factory import create_transform \ No newline at end of file diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 42a46eef..a7c5ebed 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -9,7 +9,7 @@ import logging from PIL import Image -from .parsers import ParserImageFolder, ParserImageTar, ParserImageClassInTar +from .parsers import create_parser _logger = logging.getLogger(__name__) @@ -27,11 +27,8 @@ class ImageDataset(data.Dataset): load_bytes=False, transform=None, ): - if parser is None: - if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': - parser = ParserImageTar(root, class_map=class_map) - else: - parser = ParserImageFolder(root, class_map=class_map) + if parser is None or isinstance(parser, str): + parser = create_parser(parser or '', root=root, class_map=class_map) self.parser = parser self.load_bytes = load_bytes self.transform = transform @@ -65,6 +62,49 @@ class ImageDataset(data.Dataset): return self.parser.filenames(basename, absolute) +class IterableImageDataset(data.IterableDataset): + + def __init__( + self, + root, + parser=None, + split='train', + is_training=False, + batch_size=None, + class_map='', + load_bytes=False, + transform=None, + ): + assert parser is not None + if isinstance(parser, str): + self.parser = create_parser( + parser, root=root, split=split, is_training=is_training, batch_size=batch_size) + else: + self.parser = parser + self.transform = transform + self._consecutive_errors = 0 + + def __iter__(self): + for img, target in self.parser: + if self.transform is not None: + img = self.transform(img) + if target is None: + target = torch.tensor(-1, dtype=torch.long) + yield img, target + + def __len__(self): + if hasattr(self.parser, '__len__'): + return len(self.parser) + else: + return 0 + + def filename(self, index, basename=False, absolute=False): + assert False, 'Filename lookup by index not supported, use filenames().' + + def filenames(self, basename=False, absolute=False): + return self.parser.filenames(basename, absolute) + + class AugMixDataset(torch.utils.data.Dataset): """Dataset wrapper to perform AugMix or other clean/augmentation mixes""" diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py new file mode 100644 index 00000000..b2c9688f --- /dev/null +++ b/timm/data/dataset_factory.py @@ -0,0 +1,29 @@ +import os + +from .dataset import IterableImageDataset, ImageDataset + + +def _search_split(root, split): + # look for sub-folder with name of split in root and use that if it exists + split_name = split.split('[')[0] + try_root = os.path.join(root, split_name) + if os.path.exists(try_root): + return try_root + if split_name == 'validation': + try_root = os.path.join(root, 'val') + if os.path.exists(try_root): + return try_root + return root + + +def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs): + name = name.lower() + if name.startswith('tfds'): + ds = IterableImageDataset( + root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs) + else: + # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future + if search_split and os.path.isdir(root): + root = _search_split(root, split) + ds = ImageDataset(root, parser=name, **kwargs) + return ds diff --git a/timm/data/loader.py b/timm/data/loader.py index 317f77df..76144669 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -153,7 +153,8 @@ def create_loader( pin_memory=False, fp16=False, tf_preprocessing=False, - use_multi_epochs_loader=False + use_multi_epochs_loader=False, + persistent_workers=True, ): re_num_splits = 0 if re_split: @@ -183,7 +184,7 @@ def create_loader( ) sampler = None - if distributed: + if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): if is_training: sampler = torch.utils.data.distributed.DistributedSampler(dataset) else: @@ -199,16 +200,20 @@ def create_loader( if use_multi_epochs_loader: loader_class = MultiEpochsDataLoader - loader = loader_class( - dataset, + loader_args = dict( batch_size=batch_size, - shuffle=sampler is None and is_training, + shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training, num_workers=num_workers, sampler=sampler, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=is_training, - ) + persistent_workers=persistent_workers) + try: + loader = loader_class(dataset, **loader_args) + except TypeError as e: + loader_args.pop('persistent_workers') # only in Pytorch 1.7+ + loader = loader_class(dataset, **loader_args) if use_prefetcher: prefetch_re_prob = re_prob if is_training and not no_aug else 0. loader = PrefetchLoader( diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py index 4ecb3a22..eeb44e37 100644 --- a/timm/data/parsers/__init__.py +++ b/timm/data/parsers/__init__.py @@ -1,4 +1 @@ -from .parser import Parser -from .parser_image_folder import ParserImageFolder -from .parser_image_tar import ParserImageTar -from .parser_image_class_in_tar import ParserImageClassInTar \ No newline at end of file +from .parser_factory import create_parser diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py new file mode 100644 index 00000000..ce9aa35f --- /dev/null +++ b/timm/data/parsers/parser_factory.py @@ -0,0 +1,29 @@ +import os + +from .parser_image_folder import ParserImageFolder +from .parser_image_tar import ParserImageTar +from .parser_image_class_in_tar import ParserImageClassInTar + + +def create_parser(name, root, split='train', **kwargs): + name = name.lower() + name = name.split('/', 2) + prefix = '' + if len(name) > 1: + prefix = name[0] + name = name[-1] + + # FIXME improve the selection right now just tfds prefix or fallback path, will need options to + # explicitly select other options shortly + if prefix == 'tfds': + from .parser_tfds import ParserTfds # defer tensorflow import + parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs) + else: + assert os.path.exists(root) + # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder + # FIXME support split here, in parser? + if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': + parser = ParserImageTar(root, **kwargs) + else: + parser = ParserImageFolder(root, **kwargs) + return parser diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py new file mode 100644 index 00000000..39a9243a --- /dev/null +++ b/timm/data/parsers/parser_tfds.py @@ -0,0 +1,201 @@ +""" Dataset parser interface that wraps TFDS datasets + +Wraps many (most?) TFDS image-classification datasets +from https://github.com/tensorflow/datasets +https://www.tensorflow.org/datasets/catalog/overview#image_classification + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os +import io +import math +import torch +import torch.distributed as dist +from PIL import Image + +try: + import tensorflow as tf + tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu) + import tensorflow_datasets as tfds +except ImportError as e: + print(e) + print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") + exit(1) +from .parser import Parser + + +MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities +SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue +PREFETCH_SIZE = 4096 # samples to prefetch + + +class ParserTfds(Parser): + """ Wrap Tensorflow Datasets for use in PyTorch + + There several things to be aware of: + * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of + dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last + https://github.com/pytorch/pytorch/issues/33413 + * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch + from each worker could be a different size. For training this is avoid by option above, for + validation extra samples are inserted iff distributed mode is enabled so the batches being reduced + across replicas are of same size. This will slightlyalter the results, distributed validation will not be + 100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse + since there are to N * J extra samples. + * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of + replicas and dataloader workers you can use. For really small datasets that only contain a few shards + you may have to train non-distributed w/ 1-2 dataloader workers. This may not be a huge concern as the + benefit of distributed training or fast dataloading should be much less for small datasets. + * This wrapper is currently configured to return individual, decompressed image samples from the TFDS + dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible + to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream + components. + + """ + def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None): + super().__init__() + self.root = root + self.split = split + self.shuffle = shuffle + self.is_training = is_training + if self.is_training: + assert batch_size is not None,\ + "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" + self.batch_size = batch_size + + self.builder = tfds.builder(name, data_dir=root) + # NOTE: please use tfds command line app to download & prepare datasets, I don't want to trigger + # it by default here as it's caused issues generating unwanted paths in data directories. + self.num_samples = self.builder.info.splits[split].num_examples + self.ds = None # initialized lazily on each dataloader worker process + + self.worker_info = None + self.dist_rank = 0 + self.dist_num_replicas = 1 + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + self.dist_rank = dist.get_rank() + self.dist_num_replicas = dist.get_world_size() + + def _lazy_init(self): + """ Lazily initialize the dataset. + + This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that + will be using the dataset instance. The __init__ method is called on the main process, + this will be called in a dataloader worker process. + + NOTE: There will be problems if you try to re-use this dataset across different loader/worker + instances once it has been initialized. Do not call any dataset methods that can call _lazy_init + before it is passed to dataloader. + """ + worker_info = torch.utils.data.get_worker_info() + + # setup input context to split dataset across distributed processes + split = self.split + num_workers = 1 + if worker_info is not None: + self.worker_info = worker_info + num_workers = worker_info.num_workers + worker_id = worker_info.id + + # FIXME I need to spend more time figuring out the best way to distribute/split data across + # combo of distributed replicas + dataloader worker processes + """ + InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. + My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True) + between the splits each iteration but that could be wrong. + Possible split options include: + * InputContext for both distributed & worker processes (current) + * InputContext for distributed and sub-splits for worker processes + * sub-splits for both + """ + # split_size = self.num_samples // num_workers + # start = worker_id * split_size + # if worker_id == num_workers - 1: + # split = split + '[{}:]'.format(start) + # else: + # split = split + '[{}:{}]'.format(start, start + split_size) + + input_context = tf.distribute.InputContext( + num_input_pipelines=self.dist_num_replicas * num_workers, + input_pipeline_id=self.dist_rank * num_workers + worker_id, + num_replicas_in_sync=self.dist_num_replicas # FIXME does this have any impact? + ) + + read_config = tfds.ReadConfig(input_context=input_context) + ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config) + # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers + ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) + ds.options().experimental_threading.max_intra_op_parallelism = 1 + if self.is_training: + # to prevent excessive drop_last batch behaviour w/ IterableDatasets + # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading + ds = ds.repeat() # allow wrap around and break iteration manually + if self.shuffle: + ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0) + ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) + self.ds = tfds.as_numpy(ds) + + def __iter__(self): + if self.ds is None: + self._lazy_init() + # compute a rounded up sample count that is used to: + # 1. make batches even cross workers & replicas in distributed validation. + # This adds extra samples and will slightly alter validation results. + # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size + # batches are produced (underlying tfds iter wraps around) + target_sample_count = math.ceil(self.num_samples / self._num_pipelines) + if self.is_training: + # round up to nearest batch_size per worker-replica + target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size + sample_count = 0 + for sample in self.ds: + img = Image.fromarray(sample['image'], mode='RGB') + yield img, sample['label'] + sample_count += 1 + if self.is_training and sample_count >= target_sample_count: + # Need to break out of loop when repeat() is enabled for training w/ oversampling + # this results in 'extra' samples per epoch but seems more desirable than dropping + # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes) + break + if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count: + # Validation batch padding only done for distributed training where results are reduced across nodes. + # For single process case, it won't matter if workers return different batch sizes. + # FIXME this needs more testing, possible for sharding / split api to cause differences of > 1? + assert target_sample_count - sample_count == 1 # should only be off by 1 or sharding is not optimal + yield img, sample['label'] # yield prev sample again + sample_count += 1 + + @property + def _num_workers(self): + return 1 if self.worker_info is None else self.worker_info.num_workers + + @property + def _num_pipelines(self): + return self._num_workers * self.dist_num_replicas + + def __len__(self): + # this is just an estimate and does not factor in extra samples added to pad batches based on + # complete worker & replica info (not available until init in dataloader). + return math.ceil(self.num_samples / self.dist_num_replicas) + + def _filename(self, index, basename=False, absolute=False): + assert False, "Not supported" # no random access to samples + + def filenames(self, basename=False, absolute=False): + """ Return all filenames in dataset, overrides base""" + if self.ds is None: + self._lazy_init() + names = [] + for sample in self.ds: + if len(names) > self.num_samples: + break # safety for ds.repeat() case + if 'file_name' in sample: + name = sample['file_name'] + elif 'filename' in sample: + name = sample['filename'] + elif 'id' in sample: + name = sample['id'] + else: + assert False, "No supported name field present" + names.append(name) + return names diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 2a15e528..96f551e3 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -11,7 +11,11 @@ from typing import Callable import torch import torch.nn as nn -from torch.hub import get_dir, load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX +from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX +try: + from torch.hub import get_dir +except ImportError: + from torch.hub import _get_torch_home as get_dir from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .layers import Conv2dSame, Linear diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 731f5dca..f51d6357 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -507,42 +507,42 @@ def resnetv2_152x4_bitm(pretrained=False, **kwargs): @register_model def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs): return _create_resnetv2( - 'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843), + 'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) @register_model def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs): return _create_resnetv2( - 'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843), + 'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) @register_model def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs): return _create_resnetv2( - 'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843), + 'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) @register_model def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs): return _create_resnetv2( - 'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843), + 'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) @register_model def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs): return _create_resnetv2( - 'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843), + 'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) @register_model def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs): return _create_resnetv2( - 'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843), + 'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index ff5bd676..076010ab 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -5,12 +5,6 @@ A PyTorch implement of Vision Transformers as described in The official jax code is released and available at https://github.com/google-research/vision_transformer -Status/TODO: -* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. -* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. -* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. -* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. - Acknowledgments: * The paper authors for releasing code and weights, thanks! * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out @@ -18,6 +12,9 @@ for some einops/einsum fun * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT * Bert reference code checks against Huggingface Transformers and Tensorflow Bert +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + Hacked together by / Copyright 2020 Ross Wightman """ import torch @@ -50,7 +47,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', ), - # patch models (weights ported from official JAX impl) + # patch models (weights ported from official Google JAX impl) 'vit_base_patch16_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), @@ -77,7 +74,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), - # patch models, imagenet21k (weights ported from official JAX impl) + # patch models, imagenet21k (weights ported from official Google JAX impl) 'vit_base_patch16_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), @@ -94,7 +91,7 @@ default_cfgs = { url='', # FIXME I have weights for this but > 2GB limit for github release binaries num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - # hybrid models (weights ported from official JAX impl) + # hybrid models (weights ported from official Google JAX impl) 'vit_base_resnet50_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9), @@ -107,6 +104,17 @@ default_cfgs = { 'vit_small_resnet50d_s3_224': _cfg(), 'vit_base_resnet26d_224': _cfg(), 'vit_base_resnet50d_224': _cfg(), + + # deit models (FB weights) + 'deit_tiny_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), + 'deit_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), + 'deit_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), + 'deit_base_patch16_384': _cfg( + url='', # no weights yet + input_size=(3, 384, 384)), } @@ -433,7 +441,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs): @register_model def vit_base_patch16_224_in21k(pretrained=False, **kwargs): - num_classes = kwargs.get('num_classes', 21843) + num_classes = kwargs.pop('num_classes', 21843) model = VisionTransformer( patch_size=16, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) @@ -446,7 +454,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs): @register_model def vit_base_patch32_224_in21k(pretrained=False, **kwargs): - num_classes = kwargs.get('num_classes', 21843) + num_classes = kwargs.pop('num_classes', 21843) model = VisionTransformer( img_size=224, num_classes=num_classes, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) @@ -458,7 +466,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs): @register_model def vit_large_patch16_224_in21k(pretrained=False, **kwargs): - num_classes = kwargs.get('num_classes', 21843) + num_classes = kwargs.pop('num_classes', 21843) model = VisionTransformer( patch_size=16, num_classes=num_classes, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) @@ -482,7 +490,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs): @register_model def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): - num_classes = kwargs.get('num_classes', 21843) + num_classes = kwargs.pop('num_classes', 21843) model = VisionTransformer( img_size=224, patch_size=14, num_classes=num_classes, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, representation_size=1280, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) @@ -495,7 +503,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): @register_model def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head - num_classes = kwargs.get('num_classes', 21843) + num_classes = kwargs.pop('num_classes', 21843) backbone = ResNetV2( layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='') model = VisionTransformer( @@ -559,3 +567,51 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs): img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) model.default_cfg = default_cfgs['vit_base_resnet50d_224'] return model + + +@register_model +def deit_tiny_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['deit_tiny_patch16_224'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model']) + return model + + +@register_model +def deit_small_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['deit_small_patch16_224'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model']) + return model + + +@register_model +def deit_base_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['deit_base_patch16_224'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model']) + return model + + +@register_model +def deit_base_patch16_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['deit_base_patch16_384'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model']) + return model diff --git a/train.py b/train.py index 4bb68399..b31199f9 100755 --- a/train.py +++ b/train.py @@ -28,7 +28,7 @@ import torch.nn as nn import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP -from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset +from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy @@ -64,8 +64,14 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset / Model parameters -parser.add_argument('data', metavar='DIR', +parser.add_argument('data_dir', metavar='DIR', help='path to dataset') +parser.add_argument('--dataset', '-d', metavar='NAME', default='', + help='dataset type (default: ImageFolder/ImageTar if empty)') +parser.add_argument('--train-split', metavar='NAME', default='train', + help='dataset train split (default: train)') +parser.add_argument('--val-split', metavar='NAME', default='validation', + help='dataset validation split (default: validation)') parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL', help='Name of model to train (default: "countception"') parser.add_argument('--pretrained', action='store_true', default=False, @@ -437,19 +443,10 @@ def main(): _logger.info('Scheduled epochs: {}'.format(num_epochs)) # create the train and eval datasets - train_dir = os.path.join(args.data, 'train') - if not os.path.exists(train_dir): - _logger.error('Training folder does not exist at: {}'.format(train_dir)) - exit(1) - dataset_train = ImageDataset(train_dir) - - eval_dir = os.path.join(args.data, 'val') - if not os.path.isdir(eval_dir): - eval_dir = os.path.join(args.data, 'validation') - if not os.path.isdir(eval_dir): - _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) - exit(1) - dataset_eval = ImageDataset(eval_dir) + dataset_train = create_dataset( + args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size) + dataset_eval = create_dataset( + args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) # setup mixup / cutmix collate_fn = None @@ -553,10 +550,10 @@ def main(): try: for epoch in range(start_epoch, num_epochs): - if args.distributed: - loader_train.sampler.set_epoch(epoch) + if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): + loader_train.set_epoch(epoch) - train_metrics = train_epoch( + train_metrics = train_one_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) @@ -594,7 +591,7 @@ def main(): _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) -def train_epoch( +def train_one_epoch( epoch, model, loader, optimizer, loss_fn, args, lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress, loss_scaler=None, model_ema=None, mixup_fn=None): diff --git a/validate.py b/validate.py index d9ba377c..be977cc2 100755 --- a/validate.py +++ b/validate.py @@ -20,7 +20,7 @@ from collections import OrderedDict from contextlib import suppress from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models -from timm.data import ImageDataset, create_loader, resolve_data_config, RealLabelsImagenet +from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy has_apex = False @@ -44,7 +44,11 @@ _logger = logging.getLogger('validate') parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') parser.add_argument('data', metavar='DIR', help='path to dataset') -parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92', +parser.add_argument('--dataset', '-d', metavar='NAME', default='', + help='dataset type (default: ImageFolder/ImageTar if empty)') +parser.add_argument('--split', metavar='NAME', default='validation', + help='dataset split (default: validation)') +parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', help='model architecture (default: dpn92)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 2)') @@ -159,7 +163,9 @@ def validate(args): criterion = nn.CrossEntropyLoss().cuda() - dataset = ImageDataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map) + dataset = create_dataset( + root=args.data, name=args.dataset, split=args.split, + load_bytes=args.tf_preprocessing, class_map=args.class_map) if args.valid_labels: with open(args.valid_labels, 'r') as f: