More dataset work including factories and a TensorFlow Datasets (TFDS) wrapper

* Add parser/dataset factory methods for more flexible dataset & parser creation
* Add dataset parser that wraps TFDS image classification datasets
* Fix num_classes handling bug for 21k models
* Add initial DeiT models so they can be benchmarked in the next CSV results runs
pull/323/head
Ross Wightman 3 years ago
parent 20516abc18
commit 855d6cc217
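
Taken together, the new factory pieces make dataset selection declarative: a 'tfds/' prefixed name streams through the TFDS wrapper, anything else falls back to the folder/tar parsers. A minimal sketch of the intended flow (dataset name and paths are illustrative; the TFDS data must be downloaded/prepared beforehand with the tfds CLI):

    from timm.data import create_dataset, create_loader

    # 'tfds/' prefix routes to ParserTfds via an IterableImageDataset
    dataset = create_dataset('tfds/cifar10', root='/data/tfds', split='train',
                             is_training=True, batch_size=128)
    loader = create_loader(dataset, input_size=(3, 224, 224), batch_size=128,
                           is_training=True, use_prefetcher=False)  # skip CUDA prefetch in this sketch
    images, targets = next(iter(loader))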

timm/data/__init__.py
@@ -1,10 +1,12 @@
-from .constants import *
+from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
+    rand_augment_transform, auto_augment_transform
 from .config import resolve_data_config
-from .dataset import ImageDataset, AugMixDataset
-from .transforms import *
+from .constants import *
+from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
+from .dataset_factory import create_dataset
 from .loader import create_loader
-from .transforms_factory import create_transform
 from .mixup import Mixup, FastCollateMixup
-from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
-    rand_augment_transform, auto_augment_transform
+from .parsers import create_parser
 from .real_labels import RealLabelsImagenet
+from .transforms import *
+from .transforms_factory import create_transform

timm/data/dataset.py
@@ -9,7 +9,7 @@ import logging
 from PIL import Image

-from .parsers import ParserImageFolder, ParserImageTar, ParserImageClassInTar
+from .parsers import create_parser

 _logger = logging.getLogger(__name__)
@@ -27,11 +27,8 @@ class ImageDataset(data.Dataset):
             load_bytes=False,
             transform=None,
     ):
-        if parser is None:
-            if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
-                parser = ParserImageTar(root, class_map=class_map)
-            else:
-                parser = ParserImageFolder(root, class_map=class_map)
+        if parser is None or isinstance(parser, str):
+            parser = create_parser(parser or '', root=root, class_map=class_map)
         self.parser = parser
         self.load_bytes = load_bytes
         self.transform = transform
@@ -65,6 +62,49 @@ class ImageDataset(data.Dataset):
         return self.parser.filenames(basename, absolute)
+
+
+class IterableImageDataset(data.IterableDataset):
+
+    def __init__(
+            self,
+            root,
+            parser=None,
+            split='train',
+            is_training=False,
+            batch_size=None,
+            class_map='',
+            load_bytes=False,
+            transform=None,
+    ):
+        assert parser is not None
+        if isinstance(parser, str):
+            self.parser = create_parser(
+                parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
+        else:
+            self.parser = parser
+        self.transform = transform
+        self._consecutive_errors = 0
+
+    def __iter__(self):
+        for img, target in self.parser:
+            if self.transform is not None:
+                img = self.transform(img)
+            if target is None:
+                target = torch.tensor(-1, dtype=torch.long)
+            yield img, target
+
+    def __len__(self):
+        if hasattr(self.parser, '__len__'):
+            return len(self.parser)
+        else:
+            return 0
+
+    def filename(self, index, basename=False, absolute=False):
+        assert False, 'Filename lookup by index not supported, use filenames().'
+
+    def filenames(self, basename=False, absolute=False):
+        return self.parser.filenames(basename, absolute)
+
+
 class AugMixDataset(torch.utils.data.Dataset):
     """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

timm/data/dataset_factory.py
@@ -0,0 +1,29 @@
import os

from .dataset import IterableImageDataset, ImageDataset


def _search_split(root, split):
    # look for sub-folder with name of split in root and use that if it exists
    split_name = split.split('[')[0]
    try_root = os.path.join(root, split_name)
    if os.path.exists(try_root):
        return try_root
    if split_name == 'validation':
        try_root = os.path.join(root, 'val')
        if os.path.exists(try_root):
            return try_root
    return root


def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
    name = name.lower()
    if name.startswith('tfds'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
    else:
        # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
        if search_split and os.path.isdir(root):
            root = _search_split(root, split)
        ds = ImageDataset(root, parser=name, **kwargs)
    return ds
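
A quick sketch of how the factory dispatches (paths hypothetical): an empty name resolves a folder dataset, with _search_split narrowing the root onto a matching split sub-folder; a 'tfds/' name returns the iterable wrapper instead and leaves split handling to the parser.

    # folder tree: root is narrowed to /data/imagenet/validation (or .../val) if present
    ds_folder = create_dataset('', root='/data/imagenet', split='validation')

    # TFDS: bypasses _search_split, streams via ParserTfds
    ds_tfds = create_dataset('tfds/cifar10', root='/data/tfds', split='train',
                             is_training=True, batch_size=128)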

timm/data/loader.py
@@ -153,7 +153,8 @@
         pin_memory=False,
         fp16=False,
         tf_preprocessing=False,
-        use_multi_epochs_loader=False
+        use_multi_epochs_loader=False,
+        persistent_workers=True,
 ):
     re_num_splits = 0
     if re_split:
@@ -183,7 +184,7 @@
     )

     sampler = None
-    if distributed:
+    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
         if is_training:
             sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         else:
@@ -199,16 +200,20 @@
     if use_multi_epochs_loader:
         loader_class = MultiEpochsDataLoader

-    loader = loader_class(
-        dataset,
+    loader_args = dict(
         batch_size=batch_size,
-        shuffle=sampler is None and is_training,
+        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
         num_workers=num_workers,
         sampler=sampler,
         collate_fn=collate_fn,
         pin_memory=pin_memory,
         drop_last=is_training,
-    )
+        persistent_workers=persistent_workers)
+    try:
+        loader = loader_class(dataset, **loader_args)
+    except TypeError as e:
+        loader_args.pop('persistent_workers')  # only in PyTorch 1.7+
+        loader = loader_class(dataset, **loader_args)
     if use_prefetcher:
         prefetch_re_prob = re_prob if is_training and not no_aug else 0.
         loader = PrefetchLoader(
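
The try/except duck-types DataLoader support for persistent_workers instead of parsing version strings. An equivalent up-front check (a sketch, not what create_loader actually does) could inspect the constructor signature:

    import inspect
    import torch.utils.data as data

    if 'persistent_workers' not in inspect.signature(data.DataLoader.__init__).parameters:
        loader_args.pop('persistent_workers')  # pre-1.7 PyTorch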

timm/data/parsers/__init__.py
@@ -1,4 +1 @@
-from .parser import Parser
-from .parser_image_folder import ParserImageFolder
-from .parser_image_tar import ParserImageTar
-from .parser_image_class_in_tar import ParserImageClassInTar
+from .parser_factory import create_parser

timm/data/parsers/parser_factory.py
@@ -0,0 +1,29 @@
import os

from .parser_image_folder import ParserImageFolder
from .parser_image_tar import ParserImageTar
from .parser_image_class_in_tar import ParserImageClassInTar


def create_parser(name, root, split='train', **kwargs):
    name = name.lower()
    name = name.split('/', 2)
    prefix = ''
    if len(name) > 1:
        prefix = name[0]
    name = name[-1]

    # FIXME improve the selection; right now it's just the tfds prefix or the fallback path, will need
    #  options to explicitly select other parsers shortly
    if prefix == 'tfds':
        from .parser_tfds import ParserTfds  # defer tensorflow import
        parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
    else:
        assert os.path.exists(root)
        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
        # FIXME support split here, in parser?
        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
            parser = ParserImageTar(root, **kwargs)
        else:
            parser = ParserImageFolder(root, **kwargs)
    return parser
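
Name handling in brief: 'tfds/imagenet2012' splits into prefix 'tfds' and name 'imagenet2012' and goes to ParserTfds; an empty or unprefixed name falls through to the tar/folder branch. Illustrative calls (paths hypothetical):

    from timm.data.parsers import create_parser

    p_tfds = create_parser('tfds/imagenet2012', root='/data/tfds', split='train')  # TFDS-backed
    p_folder = create_parser('', root='/data/imagenet/train')    # class-per-sub-folder tree
    p_tar = create_parser('', root='/data/imagenet_train.tar')   # single tar archive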

timm/data/parsers/parser_tfds.py
@@ -0,0 +1,201 @@
""" Dataset parser interface that wraps TFDS datasets
Wraps many (most?) TFDS image-classification datasets
from https://github.com/tensorflow/datasets
https://www.tensorflow.org/datasets/catalog/overview#image_classification
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import io
import math
import torch
import torch.distributed as dist
from PIL import Image
try:
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
import tensorflow_datasets as tfds
except ImportError as e:
print(e)
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
exit(1)
from .parser import Parser
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
PREFETCH_SIZE = 4096 # samples to prefetch
class ParserTfds(Parser):
""" Wrap Tensorflow Datasets for use in PyTorch
There several things to be aware of:
* To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
https://github.com/pytorch/pytorch/issues/33413
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
from each worker could be a different size. For training this is avoid by option above, for
validation extra samples are inserted iff distributed mode is enabled so the batches being reduced
across replicas are of same size. This will slightlyalter the results, distributed validation will not be
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
since there are to N * J extra samples.
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
you may have to train non-distributed w/ 1-2 dataloader workers. This may not be a huge concern as the
benefit of distributed training or fast dataloading should be much less for small datasets.
* This wrapper is currently configured to return individual, decompressed image samples from the TFDS
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
components.
"""
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
super().__init__()
self.root = root
self.split = split
self.shuffle = shuffle
self.is_training = is_training
if self.is_training:
assert batch_size is not None,\
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size
self.builder = tfds.builder(name, data_dir=root)
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to trigger
# it by default here as it's caused issues generating unwanted paths in data directories.
self.num_samples = self.builder.info.splits[split].num_examples
self.ds = None # initialized lazily on each dataloader worker process
self.worker_info = None
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
def _lazy_init(self):
""" Lazily initialize the dataset.
This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
will be using the dataset instance. The __init__ method is called on the main process,
this will be called in a dataloader worker process.
NOTE: There will be problems if you try to re-use this dataset across different loader/worker
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
before it is passed to dataloader.
"""
worker_info = torch.utils.data.get_worker_info()
# setup input context to split dataset across distributed processes
split = self.split
num_workers = 1
if worker_info is not None:
self.worker_info = worker_info
num_workers = worker_info.num_workers
worker_id = worker_info.id
# FIXME I need to spend more time figuring out the best way to distribute/split data across
# combo of distributed replicas + dataloader worker processes
"""
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
between the splits each iteration but that could be wrong.
Possible split options include:
* InputContext for both distributed & worker processes (current)
* InputContext for distributed and sub-splits for worker processes
* sub-splits for both
"""
# split_size = self.num_samples // num_workers
# start = worker_id * split_size
# if worker_id == num_workers - 1:
# split = split + '[{}:]'.format(start)
# else:
# split = split + '[{}:{}]'.format(start, start + split_size)
input_context = tf.distribute.InputContext(
num_input_pipelines=self.dist_num_replicas * num_workers,
input_pipeline_id=self.dist_rank * num_workers + worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this have any impact?
)
read_config = tfds.ReadConfig(input_context=input_context)
ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
ds.options().experimental_threading.max_intra_op_parallelism = 1
if self.is_training:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.shuffle:
ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
self.ds = tfds.as_numpy(ds)
def __iter__(self):
if self.ds is None:
self._lazy_init()
# compute a rounded up sample count that is used to:
# 1. make batches even cross workers & replicas in distributed validation.
# This adds extra samples and will slightly alter validation results.
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
# batches are produced (underlying tfds iter wraps around)
target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
if self.is_training:
# round up to nearest batch_size per worker-replica
target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
sample_count = 0
for sample in self.ds:
img = Image.fromarray(sample['image'], mode='RGB')
yield img, sample['label']
sample_count += 1
if self.is_training and sample_count >= target_sample_count:
# Need to break out of loop when repeat() is enabled for training w/ oversampling
# this results in 'extra' samples per epoch but seems more desirable than dropping
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
break
if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
# Validation batch padding only done for distributed training where results are reduced across nodes.
# For single process case, it won't matter if workers return different batch sizes.
# FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
assert target_sample_count - sample_count == 1 # should only be off by 1 or sharding is not optimal
yield img, sample['label'] # yield prev sample again
sample_count += 1
@property
def _num_workers(self):
return 1 if self.worker_info is None else self.worker_info.num_workers
@property
def _num_pipelines(self):
return self._num_workers * self.dist_num_replicas
def __len__(self):
# this is just an estimate and does not factor in extra samples added to pad batches based on
# complete worker & replica info (not available until init in dataloader).
return math.ceil(self.num_samples / self.dist_num_replicas)
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to samples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
if self.ds is None:
self._lazy_init()
names = []
for sample in self.ds:
if len(names) > self.num_samples:
break # safety for ds.repeat() case
if 'file_name' in sample:
name = sample['file_name']
elif 'filename' in sample:
name = sample['filename']
elif 'id' in sample:
name = sample['id']
else:
assert False, "No supported name field present"
names.append(name)
return names
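
To make the wrap-around/padding arithmetic concrete (numbers illustrative): with 50000 samples, 2 replicas and 4 workers there are 8 pipelines, so each pipeline targets ceil(50000 / 8) = 6250 samples; training with batch_size=128 rounds that up to 6272 so every worker emits only full batches from the repeated stream.

    import math

    num_samples, replicas, workers, batch_size = 50000, 2, 4, 128
    num_pipelines = replicas * workers                           # 8
    target = math.ceil(num_samples / num_pipelines)              # 6250
    train_target = math.ceil(target / batch_size) * batch_size   # 6272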

timm/models/helpers.py
@@ -11,7 +11,11 @@ from typing import Callable
 import torch
 import torch.nn as nn

-from torch.hub import get_dir, load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
+from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
+try:
+    from torch.hub import get_dir
+except ImportError:
+    from torch.hub import _get_torch_home as get_dir

 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .layers import Conv2dSame, Linear

timm/models/resnetv2.py
@@ -507,42 +507,42 @@ def resnetv2_152x4_bitm(pretrained=False, **kwargs):
 @register_model
 def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
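
The get -> pop change is the 21k num_classes fix from the commit message: with get, num_classes remains in kwargs and is forwarded again through **kwargs, so the creation fn receives it twice. A minimal repro with a hypothetical builder:

    def build(num_classes=1000, **kwargs):
        return num_classes

    kwargs = {'num_classes': 100}
    try:
        build(num_classes=kwargs.get('num_classes', 21843), **kwargs)
    except TypeError as e:
        print(e)  # got multiple values for keyword argument 'num_classes'
    print(build(num_classes=kwargs.pop('num_classes', 21843), **kwargs))  # 100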

timm/models/vision_transformer.py
@@ -5,12 +5,6 @@ A PyTorch implementation of Vision Transformers as described in
 The official jax code is released and available at https://github.com/google-research/vision_transformer

-Status/TODO:
-* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
-* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
-* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
-* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.

 Acknowledgments:
 * The paper authors for releasing code and weights, thanks!
 * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
@@ -18,6 +12,9 @@ for some einops/einsum fun
 * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
 * Bert reference code checks against Huggingface Transformers and Tensorflow Bert

+DeiT model defs and weights from https://github.com/facebookresearch/deit,
+paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877

 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
@@ -50,7 +47,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
     ),

-    # patch models (weights ported from official JAX impl)
+    # patch models (weights ported from official Google JAX impl)
     'vit_base_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
@@ -77,7 +74,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),

-    # patch models, imagenet21k (weights ported from official JAX impl)
+    # patch models, imagenet21k (weights ported from official Google JAX impl)
     'vit_base_patch16_224_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
@@ -94,7 +91,7 @@ default_cfgs = {
         url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

-    # hybrid models (weights ported from official JAX impl)
+    # hybrid models (weights ported from official Google JAX impl)
     'vit_base_resnet50_224_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9),
@@ -107,6 +104,17 @@ default_cfgs = {
     'vit_small_resnet50d_s3_224': _cfg(),
     'vit_base_resnet26d_224': _cfg(),
     'vit_base_resnet50d_224': _cfg(),
+
+    # deit models (FB weights)
+    'deit_tiny_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
+    'deit_small_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
+    'deit_base_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
+    'deit_base_patch16_384': _cfg(
+        url='',  # no weights yet
+        input_size=(3, 384, 384)),
 }
@@ -433,7 +441,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 @register_model
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     model = VisionTransformer(
         patch_size=16, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
         representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -446,7 +454,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     model = VisionTransformer(
         img_size=224, num_classes=num_classes, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
         qkv_bias=True, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -458,7 +466,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     model = VisionTransformer(
         patch_size=16, num_classes=num_classes, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
         representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -482,7 +490,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     model = VisionTransformer(
         img_size=224, patch_size=14, num_classes=num_classes, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
         qkv_bias=True, representation_size=1280, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -495,7 +503,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
     # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     backbone = ResNetV2(
         layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='')
     model = VisionTransformer(
@@ -559,3 +567,51 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs):
         img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
     model.default_cfg = default_cfgs['vit_base_resnet50d_224']
     return model
+
+
+@register_model
+def deit_tiny_patch16_224(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['deit_tiny_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
+    return model
+
+
+@register_model
+def deit_small_patch16_224(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['deit_small_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
+    return model
+
+
+@register_model
+def deit_base_patch16_224(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['deit_base_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
+    return model
+
+
+@register_model
+def deit_base_patch16_384(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['deit_base_patch16_384']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
+    return model
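
The filter_fn=lambda x: x['model'] is needed because the FB DeiT checkpoints nest the state dict under a 'model' key. Typical use then goes through the registry as usual (a sketch; assumes this branch of timm is installed):

    import timm

    model = timm.create_model('deit_base_patch16_224', pretrained=True)
    model.eval()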

train.py
@@ -28,7 +28,7 @@ import torch.nn as nn
 import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

-from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
@@ -64,8 +64,14 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

 # Dataset / Model parameters
-parser.add_argument('data', metavar='DIR',
+parser.add_argument('data_dir', metavar='DIR',
                     help='path to dataset')
+parser.add_argument('--dataset', '-d', metavar='NAME', default='',
+                    help='dataset type (default: ImageFolder/ImageTar if empty)')
+parser.add_argument('--train-split', metavar='NAME', default='train',
+                    help='dataset train split (default: train)')
+parser.add_argument('--val-split', metavar='NAME', default='validation',
+                    help='dataset validation split (default: validation)')
 parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                     help='Name of model to train (default: "resnet101")')
 parser.add_argument('--pretrained', action='store_true', default=False,
@@ -437,19 +443,10 @@ def main():
         _logger.info('Scheduled epochs: {}'.format(num_epochs))

     # create the train and eval datasets
-    train_dir = os.path.join(args.data, 'train')
-    if not os.path.exists(train_dir):
-        _logger.error('Training folder does not exist at: {}'.format(train_dir))
-        exit(1)
-    dataset_train = ImageDataset(train_dir)
-
-    eval_dir = os.path.join(args.data, 'val')
-    if not os.path.isdir(eval_dir):
-        eval_dir = os.path.join(args.data, 'validation')
-        if not os.path.isdir(eval_dir):
-            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
-            exit(1)
-    dataset_eval = ImageDataset(eval_dir)
+    dataset_train = create_dataset(
+        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
+    dataset_eval = create_dataset(
+        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

     # setup mixup / cutmix
     collate_fn = None
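
With the new arguments the folder layout keeps working unchanged, while TFDS datasets stream directly; illustrative invocations (dataset name hypothetical, TFDS data prepared beforehand):

    # ImageFolder layout, train/ and validation/ (or val/) sub-folders auto-discovered
    python train.py /data/imagenet --model resnet50 -b 128

    # TFDS-backed dataset via the tfds/ prefix
    python train.py /data/tfds -d tfds/oxford_iiit_pet --train-split train --val-split test --model resnet50 -b 128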
@@ -553,10 +550,10 @@ def main():
     try:
         for epoch in range(start_epoch, num_epochs):
-            if args.distributed:
-                loader_train.sampler.set_epoch(epoch)
+            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
+                loader_train.sampler.set_epoch(epoch)

-            train_metrics = train_epoch(
+            train_metrics = train_one_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
@@ -594,7 +591,7 @@ def main():
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


-def train_epoch(
+def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
         lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
         loss_scaler=None, model_ema=None, mixup_fn=None):

validate.py
@@ -20,7 +20,7 @@ from collections import OrderedDict
 from contextlib import suppress

 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import ImageDataset, create_loader, resolve_data_config, RealLabelsImagenet
+from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy

 has_apex = False
@@ -44,7 +44,11 @@ _logger = logging.getLogger('validate')
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
-parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
+parser.add_argument('--dataset', '-d', metavar='NAME', default='',
+                    help='dataset type (default: ImageFolder/ImageTar if empty)')
+parser.add_argument('--split', metavar='NAME', default='validation',
+                    help='dataset split (default: validation)')
+parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
@@ -159,7 +163,9 @@ def validate(args):
     criterion = nn.CrossEntropyLoss().cuda()

-    dataset = ImageDataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
+    dataset = create_dataset(
+        root=args.data, name=args.dataset, split=args.split,
+        load_bytes=args.tf_preprocessing, class_map=args.class_map)

     if args.valid_labels:
         with open(args.valid_labels, 'r') as f:
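
Validation gets the same treatment; an illustrative run against a TFDS test split (dataset name hypothetical):

    python validate.py /data/tfds -d tfds/cifar10 --split test --model resnet50 --pretrained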
