Merge pull request #323 from rwightman/imagenet21k_datasets_more

BiT (Big Transfer) ResNetV2 models, Official ViT Hybrid R50 weights, ViT IN21K weights updated w/ repr layer, ImageNet21k and dataset / parser refactor
Ross Wightman 4 years ago committed by GitHub
commit 9a38416fbd

@@ -2,6 +2,19 @@
## What's New
### Jan 25, 2021
* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support
* NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning
* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit
* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes
* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script
* Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2`
* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar
* Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp`
* Models in this update should be stable, with the possible exception of ViT/BiT; some regressions are possible in the train/val scripts and dataset handling. A usage sketch of the new dataset API follows this list.
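As a quick sketch of the refactored pipeline mentioned above (`create_dataset`/`create_loader` are from this update's `timm.data`; the paths and sizes here are illustrative assumptions, and the prefetcher assumes a CUDA device):

```python
from timm.data import create_dataset, create_loader

# folder-based dataset (parser inferred from root); use a 'tfds/...' name for TFDS sets
ds = create_dataset('', root='/data/imagenet', split='validation')

loader = create_loader(
    ds,
    input_size=(3, 224, 224),  # normally resolved from the model's default data config
    batch_size=256,
    is_training=False,
    use_prefetcher=True)  # prefetcher yields CUDA batches

for images, targets in loader:
    break  # grab one batch
```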
### Jan 3, 2021
* Add SE-ResNet-152D weights
  * 256x256 val, 0.94 crop top-1 - 83.75
@@ -130,7 +143,9 @@ All model architecture families include variants with pretrained weights. The ar
A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).

* Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
* DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
* DenseNet - https://arxiv.org/abs/1608.06993
* DLA - https://arxiv.org/abs/1707.06484
* DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
@@ -242,6 +257,10 @@ One of the greatest assets of PyTorch is the community and their contributions.
* Albumentations - https://github.com/albumentations-team/albumentations
* Kornia - https://github.com/kornia/kornia

### Knowledge Distillation
* RepDistiller - https://github.com/HobbitLong/RepDistiller
* torchdistill - https://github.com/yoshitomo-matsubara/torchdistill

### Metric Learning
* PyTorch Metric Learning - https://github.com/KevinMusgrave/pytorch-metric-learning

@@ -10,6 +10,10 @@ Most included models have pretrained weights. The weights are either:

The validation results for the pretrained weights can be found [here](results.md)

## Big Transfer ResNetV2 (BiT) [[resnetv2.py](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/resnetv2.py)]
* Paper: `Big Transfer (BiT): General Visual Representation Learning` - https://arxiv.org/abs/1912.11370
* Reference code: https://github.com/google-research/big_transfer

## Cross-Stage Partial Networks [[cspnet.py](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/cspnet.py)]
* Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
* Reference impl: https://github.com/WongKinYiu/CrossStagePartialNetworks

@@ -13,7 +13,7 @@ import numpy as np
import torch

from timm.models import create_model, apply_test_time_pool
-from timm.data import Dataset, create_loader, resolve_data_config
+from timm.data import ImageDataset, create_loader, resolve_data_config
from timm.utils import AverageMeter, setup_default_logging

torch.backends.cudnn.benchmark = True
@@ -83,7 +83,7 @@ def main():
    model = model.cuda()

    loader = create_loader(
-        Dataset(args.data),
+        ImageDataset(args.data),
        input_size=config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,

File diff suppressed because it is too large

@@ -13,11 +13,16 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(False)

# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = ['vit_*']

# exclude models that cause specific test failures
if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
    # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
-    EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', 'vit_*']
+    EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm'] + NON_STD_FILTERS
else:
-    EXCLUDE_FILTERS = ['vit_*']
+    EXCLUDE_FILTERS = NON_STD_FILTERS

MAX_FWD_SIZE = 384
MAX_BWD_SIZE = 128
MAX_FWD_FEAT_SIZE = 448
@@ -68,7 +73,7 @@ def test_model_backward(model_name, batch_size):


@pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters=['vit_*']))
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs(model_name, batch_size):
    """Run a single forward pass with each model"""
@@ -121,7 +126,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
            create_model(model_name, pretrained=True, in_chans=in_chans)


@pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=['vit_*']))
+@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_features_pretrained(model_name, batch_size):
    """Check that pretrained weights load when features_only==True."""

@ -1,10 +1,12 @@
from .constants import * from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .config import resolve_data_config from .config import resolve_data_config
from .dataset import Dataset, DatasetTar, AugMixDataset from .constants import *
from .transforms import * from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader from .loader import create_loader
from .transforms_factory import create_transform
from .mixup import Mixup, FastCollateMixup from .mixup import Mixup, FastCollateMixup
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ from .parsers import create_parser
rand_augment_transform, auto_augment_transform
from .real_labels import RealLabelsImagenet from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform

@@ -3,172 +3,106 @@
(This file is rewritten: the old Dataset/DatasetTar classes and the folder/tar scanning helpers move, largely verbatim, into the new timm/data/parsers modules shown further below. New contents:)

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.utils.data as data
import os
import torch
import logging

from PIL import Image

from .parsers import create_parser

_logger = logging.getLogger(__name__)


_ERROR_RETRY = 50


class ImageDataset(data.Dataset):

    def __init__(
            self,
            root,
            parser=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
        if parser is None or isinstance(parser, str):
            parser = create_parser(parser or '', root=root, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform
        self._consecutive_errors = 0

    def __getitem__(self, index):
        img, target = self.parser[index]
        try:
            img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors < _ERROR_RETRY:
                return self.__getitem__((index + 1) % len(self.parser))
            else:
                raise e
        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)
        return img, target

    def __len__(self):
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)


class IterableImageDataset(data.IterableDataset):

    def __init__(
            self,
            root,
            parser=None,
            split='train',
            is_training=False,
            batch_size=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
        assert parser is not None
        if isinstance(parser, str):
            self.parser = create_parser(
                parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
        else:
            self.parser = parser
        self.transform = transform
        self._consecutive_errors = 0

    def __iter__(self):
        for img, target in self.parser:
            if self.transform is not None:
                img = self.transform(img)
            if target is None:
                target = torch.tensor(-1, dtype=torch.long)
            yield img, target

    def __len__(self):
        if hasattr(self.parser, '__len__'):
            return len(self.parser)
        else:
            return 0

    def filename(self, index, basename=False, absolute=False):
        assert False, 'Filename lookup by index not supported, use filenames().'

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)


class AugMixDataset(torch.utils.data.Dataset):

@@ -0,0 +1,29 @@
import os
from .dataset import IterableImageDataset, ImageDataset
def _search_split(root, split):
# look for sub-folder with name of split in root and use that if it exists
split_name = split.split('[')[0]
try_root = os.path.join(root, split_name)
if os.path.exists(try_root):
return try_root
if split_name == 'validation':
try_root = os.path.join(root, 'val')
if os.path.exists(try_root):
return try_root
return root
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
name = name.lower()
if name.startswith('tfds'):
ds = IterableImageDataset(
root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
else:
# FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
if search_split and os.path.isdir(root):
root = _search_split(root, split)
ds = ImageDataset(root, parser=name, **kwargs)
return ds
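For example, with the `_search_split` behaviour above, a hypothetical layout of `/data/my_dataset/train/...` and `/data/my_dataset/val/...` resolves splits automatically:

```python
from timm.data import create_dataset

train_ds = create_dataset('', root='/data/my_dataset', split='train', is_training=True)
# 'validation' falls back to a 'val' sub-folder when no 'validation' folder exists
val_ds = create_dataset('', root='/data/my_dataset', split='validation')
```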

@@ -153,7 +153,8 @@ def create_loader(
        pin_memory=False,
        fp16=False,
        tf_preprocessing=False,
-        use_multi_epochs_loader=False
+        use_multi_epochs_loader=False,
+        persistent_workers=True,
):
    re_num_splits = 0
    if re_split:
@@ -183,7 +184,7 @@ def create_loader(
    )

    sampler = None
-    if distributed:
+    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
@@ -199,16 +200,20 @@ def create_loader(
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

-    loader = loader_class(
-        dataset,
+    loader_args = dict(
        batch_size=batch_size,
-        shuffle=sampler is None and is_training,
+        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
-    )
+        persistent_workers=persistent_workers)
+    try:
+        loader = loader_class(dataset, **loader_args)
+    except TypeError as e:
+        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+
+        loader = loader_class(dataset, **loader_args)

    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(

@@ -0,0 +1 @@
from .parser_factory import create_parser

@@ -0,0 +1,16 @@
import os
def load_class_map(filename, root=''):
class_map_path = filename
if not os.path.exists(class_map_path):
class_map_path = os.path.join(root, filename)
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
class_map_ext = os.path.splitext(filename)[-1].lower()
if class_map_ext == '.txt':
with open(class_map_path) as f:
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
else:
assert False, 'Unsupported class map extension'
return class_to_idx
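A hedged example of the `.txt` class map format `load_class_map` expects, one class name per line mapped to its zero-based line index (file name and class names are made up; module path per this PR's layout):

```python
from timm.data.parsers.class_map import load_class_map

# contents of /data/my_dataset/class_map.txt (hypothetical):
#   dog
#   cat
#   bird
class_to_idx = load_class_map('class_map.txt', root='/data/my_dataset')
assert class_to_idx == {'dog': 0, 'cat': 1, 'bird': 2}
```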

@@ -0,0 +1 @@
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

@@ -0,0 +1,17 @@
from abc import abstractmethod
class Parser:
def __init__(self):
pass
@abstractmethod
def _filename(self, index, basename=False, absolute=False):
pass
def filename(self, index, basename=False, absolute=False):
return self._filename(index, basename=basename, absolute=absolute)
def filenames(self, basename=False, absolute=False):
return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
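A minimal sketch of a custom parser against this base class (entirely hypothetical; a real parser would index actual files): supply `__getitem__` returning a `(file-like, target)` pair, plus `__len__` and `_filename`, and the instance can be passed straight to `ImageDataset` via its `parser` argument.

```python
import io

from timm.data.parsers.parser import Parser


class ParserInMemory(Parser):
    """Hypothetical parser serving (name, image_bytes, label) tuples held in memory."""

    def __init__(self, samples):
        super().__init__()
        self.samples = samples

    def __getitem__(self, index):
        name, img_bytes, label = self.samples[index]
        return io.BytesIO(img_bytes), label  # file-like object, as the folder/tar parsers return

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        return self.samples[index][0]
```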

@@ -0,0 +1,29 @@
import os
from .parser_image_folder import ParserImageFolder
from .parser_image_tar import ParserImageTar
from .parser_image_in_tar import ParserImageInTar
def create_parser(name, root, split='train', **kwargs):
name = name.lower()
name = name.split('/', 2)
prefix = ''
if len(name) > 1:
prefix = name[0]
name = name[-1]
# FIXME improve the selection; right now it's just the tfds prefix or the fallback path, will need options to
# explicitly select other options shortly
if prefix == 'tfds':
from .parser_tfds import ParserTfds # defer tensorflow import
parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
# FIXME support split here, in parser?
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
parser = ParserImageInTar(root, **kwargs)
else:
parser = ParserImageFolder(root, **kwargs)
return parser

@@ -0,0 +1,69 @@
""" A dataset parser that reads images from folders
Folders are scanned recursively to find image files. Labels are based
on the folder hierarchy, just leaf folders by default.
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
labels = []
filenames = []
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
for f in files:
base, ext = os.path.splitext(f)
if ext.lower() in types:
filenames.append(os.path.join(root, f))
labels.append(label)
if class_to_idx is None:
# building class index
unique_labels = set(labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
if sort:
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
return images_and_targets, class_to_idx
class ParserImageFolder(Parser):
def __init__(
self,
root,
class_map=''):
super().__init__()
self.root = root
class_to_idx = None
if class_map:
class_to_idx = load_class_map(class_map, root)
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
if len(self.samples) == 0:
raise RuntimeError(
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
def __getitem__(self, index):
path, target = self.samples[index]
return open(path, 'rb'), target
def __len__(self):
return len(self.samples)
def _filename(self, index, basename=False, absolute=False):
filename = self.samples[index][0]
if basename:
filename = os.path.basename(filename)
elif not absolute:
filename = os.path.relpath(filename, self.root)
return filename

@@ -0,0 +1,222 @@
""" A dataset parser that reads tarfile based datasets
This parser can read and extract image samples from:
* a single tar of image files
* a folder of multiple tarfiles containing image files
* a tar of tars containing image files
Labels are based on the combined folder and/or tar name structure.
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import tarfile
import pickle
import logging
import numpy as np
from glob import glob
from typing import List, Dict
from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
_logger = logging.getLogger(__name__)
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
class TarState:
def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
self.tf: tarfile.TarFile = tf
self.ti: tarfile.TarInfo = ti
self.children: Dict[str, TarState] = {} # child states (tars within tars)
def reset(self):
self.tf = None
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
sample_count = 0
for i, ti in enumerate(tf):
if not ti.isfile():
continue
dirname, basename = os.path.split(ti.path)
name, ext = os.path.splitext(basename)
ext = ext.lower()
if ext == '.tar':
with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
child_info = dict(
name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
_logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
parent_info['children'].append(child_info)
elif ext in extensions:
parent_info['samples'].append(ti)
sample_count += 1
return sample_count
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
root_is_tar = False
if os.path.isfile(root):
assert os.path.splitext(root)[-1].lower() == '.tar'
tar_filenames = [root]
root, root_name = os.path.split(root)
root_name = os.path.splitext(root_name)[0]
root_is_tar = True
else:
root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
num_tars = len(tar_filenames)
tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
assert num_tars, f'No .tar files found at specified path ({root}).'
_logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
info = dict(tartrees=[])
cache_path = ''
if cache_tarinfo is None:
cache_tarinfo = True if tar_bytes > 10*1024**3 else False # FIXME magic number, 10GB
if cache_tarinfo:
cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
cache_path = os.path.join(root, cache_filename)
if os.path.exists(cache_path):
_logger.info(f'Reading tar info from cache file {cache_path}.')
with open(cache_path, 'rb') as pf:
info = pickle.load(pf)
assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
else:
for i, fn in enumerate(tar_filenames):
path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
with tarfile.open(fn, mode='r|') as tf: # tarinfo scans done in streaming mode
parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
num_children = len(parent_info["children"])
_logger.debug(
f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
info['tartrees'].append(parent_info)
if cache_path:
_logger.info(f'Writing tar info to cache file {cache_path}.')
with open(cache_path, 'wb') as pf:
pickle.dump(info, pf)
samples = []
labels = []
build_class_map = False
if class_name_to_idx is None:
build_class_map = True
# Flatten tartree info into lists of samples and targets w/ targets based on label id via
# class map arg or from unique paths.
# NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
# this covers my current use cases and keeps things a little easier to test for now.
tarfiles = []
def _label_from_paths(*path, leaf_only=True):
path = os.path.join(*path).strip(os.path.sep)
return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')
def _add_samples(info, fn):
added = 0
for s in info['samples']:
label = _label_from_paths(info['path'], os.path.dirname(s.path))
if not build_class_map and label not in class_name_to_idx:
continue
samples.append((s, fn, info['ti']))
labels.append(label)
added += 1
return added
_logger.info(f'Collecting samples and building tar states.')
for parent_info in info['tartrees']:
# if tartree has children, we assume all samples are at the child level
tar_name = None if root_is_tar else parent_info['name']
tar_state = TarState()
parent_added = 0
for child_info in parent_info['children']:
child_added = _add_samples(child_info, fn=tar_name)
if child_added:
tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
parent_added += child_added
parent_added += _add_samples(parent_info, fn=tar_name)
if parent_added:
tarfiles.append((tar_name, tar_state))
del info
if build_class_map:
# build class index
sorted_labels = list(sorted(set(labels), key=natural_key))
class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
_logger.info(f'Mapping targets and sorting samples.')
samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
if sort:
samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
samples, targets = zip(*samples_and_targets)
samples = np.array(samples)
targets = np.array(targets)
_logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
return samples, targets, class_name_to_idx, tarfiles
class ParserImageInTar(Parser):
""" Multi-tarfile dataset parser where there is one .tar file per class
"""
def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
super().__init__()
class_name_to_idx = None
if class_map:
class_name_to_idx = load_class_map(class_map, root)
self.root = root
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
self.root,
class_name_to_idx=class_name_to_idx,
cache_tarinfo=cache_tarinfo,
extensions=IMG_EXTENSIONS)
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
if len(tarfiles) == 1 and tarfiles[0][0] is None:
self.root_is_tar = True
self.tar_state = tarfiles[0][1]
else:
self.root_is_tar = False
self.tar_state = dict(tarfiles)
self.cache_tarfiles = cache_tarfiles
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
sample = self.samples[index]
target = self.targets[index]
sample_ti, parent_fn, child_ti = sample
parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
tf = None
cache_state = None
if self.cache_tarfiles:
cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
tf = cache_state.tf
if tf is None:
tf = tarfile.open(parent_abs)
if self.cache_tarfiles:
cache_state.tf = tf
if child_ti is not None:
ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None
if ctf is None:
ctf = tarfile.open(fileobj=tf.extractfile(child_ti))
if self.cache_tarfiles:
cache_state.children[child_ti.name].tf = ctf
tf = ctf
return tf.extractfile(sample_ti), target
def _filename(self, index, basename=False, absolute=False):
filename = self.samples[index][0].name
if basename:
filename = os.path.basename(filename)
return filename
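In Python, the tar-based dataset from the `validate.py` example in the notes above can be built directly; a sketch (path illustrative, and the first access triggers the tar scan):

```python
from timm.data import ImageDataset

# root may be a single .tar, a folder of .tar files, or a tar of tars;
# the default parser fallback selects ParserImageInTar for a .tar root
ds = ImageDataset('/data/fall11_whole.tar')
img, target = ds[0]
```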

@@ -0,0 +1,72 @@
""" A dataset parser that reads single tarfile based datasets
This parser can read datasets consisting of a single tarfile containing images.
I am planning to deprecate it in favour of ParserImageInTar.
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import tarfile
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from timm.utils.misc import natural_key
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
files = []
labels = []
for ti in tarfile.getmembers():
if not ti.isfile():
continue
dirname, basename = os.path.split(ti.path)
label = os.path.basename(dirname)
ext = os.path.splitext(basename)[1]
if ext.lower() in IMG_EXTENSIONS:
files.append(ti)
labels.append(label)
if class_to_idx is None:
unique_labels = set(labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
if sort:
tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
return tarinfo_and_targets, class_to_idx
class ParserImageTar(Parser):
""" Single tarfile dataset where classes are mapped to folders within tar
NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
operate on folders of tars or tars in tars.
"""
def __init__(self, root, class_map=''):
super().__init__()
class_to_idx = None
if class_map:
class_to_idx = load_class_map(class_map, root)
assert os.path.isfile(root)
self.root = root
with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
self.imgs = self.samples
self.tarfile = None # lazy init in __getitem__
def __getitem__(self, index):
if self.tarfile is None:
self.tarfile = tarfile.open(self.root)
tarinfo, target = self.samples[index]
fileobj = self.tarfile.extractfile(tarinfo)
return fileobj, target
def __len__(self):
return len(self.samples)
def _filename(self, index, basename=False, absolute=False):
filename = self.samples[index][0].name
if basename:
filename = os.path.basename(filename)
return filename

@@ -0,0 +1,201 @@
""" Dataset parser interface that wraps TFDS datasets
Wraps many (most?) TFDS image-classification datasets
from https://github.com/tensorflow/datasets
https://www.tensorflow.org/datasets/catalog/overview#image_classification
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import io
import math
import torch
import torch.distributed as dist
from PIL import Image
try:
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
import tensorflow_datasets as tfds
except ImportError as e:
print(e)
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
exit(1)
from .parser import Parser
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
PREFETCH_SIZE = 4096 # samples to prefetch
class ParserTfds(Parser):
""" Wrap Tensorflow Datasets for use in PyTorch
There are several things to be aware of:
* To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
https://github.com/pytorch/pytorch/issues/33413
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
from each worker could be a different size. For training this is worked around by option above, for
validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
across replicas are of same size. This will slightly alter the results, distributed validation will not be
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
since there are up to N * J extra samples with IterableDatasets.
* The sharding (splitting of the dataset into TFRecord files) imposes limitations on the number of
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
benefit of distributed training or fast dataloading should be much less for small datasets.
* This wrapper is currently configured to return individual, decompressed image samples from the TFDS
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
components.
"""
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
super().__init__()
self.root = root
self.split = split
self.shuffle = shuffle
self.is_training = is_training
if self.is_training:
assert batch_size is not None,\
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size
self.builder = tfds.builder(name, data_dir=root)
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
# download_and_prepare() by default here as it's caused issues generating unwanted paths.
self.num_samples = self.builder.info.splits[split].num_examples
self.ds = None # initialized lazily on each dataloader worker process
self.worker_info = None
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
def _lazy_init(self):
""" Lazily initialize the dataset.
This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
will be using the dataset instance. The __init__ method is called on the main process,
this will be called in a dataloader worker process.
NOTE: There will be problems if you try to re-use this dataset across different loader/worker
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
before it is passed to dataloader.
"""
worker_info = torch.utils.data.get_worker_info()
# setup input context to split dataset across distributed processes
split = self.split
num_workers = 1
if worker_info is not None:
self.worker_info = worker_info
num_workers = worker_info.num_workers
worker_id = worker_info.id
# FIXME I need to spend more time figuring out the best way to distribute/split data across
# combo of distributed replicas + dataloader worker processes
"""
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
My understanding is that using split, the underlying TFRecord files will shuffle (shuffle_files=True)
between the splits each iteration, but that understanding could be wrong.
Possible split options include:
* InputContext for both distributed & worker processes (current)
* InputContext for distributed and sub-splits for worker processes
* sub-splits for both
"""
# split_size = self.num_samples // num_workers
# start = worker_id * split_size
# if worker_id == num_workers - 1:
# split = split + '[{}:]'.format(start)
# else:
# split = split + '[{}:{}]'.format(start, start + split_size)
input_context = tf.distribute.InputContext(
num_input_pipelines=self.dist_num_replicas * num_workers,
input_pipeline_id=self.dist_rank * num_workers + worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this have any impact?
)
read_config = tfds.ReadConfig(input_context=input_context)
ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
ds.options().experimental_threading.max_intra_op_parallelism = 1
if self.is_training:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.shuffle:
ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
self.ds = tfds.as_numpy(ds)
def __iter__(self):
if self.ds is None:
self._lazy_init()
# compute a rounded up sample count that is used to:
# 1. make batches even cross workers & replicas in distributed validation.
# This adds extra samples and will slightly alter validation results.
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
# batches are produced (underlying tfds iter wraps around)
target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
if self.is_training:
# round up to nearest batch_size per worker-replica
target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
sample_count = 0
for sample in self.ds:
img = Image.fromarray(sample['image'], mode='RGB')
yield img, sample['label']
sample_count += 1
if self.is_training and sample_count >= target_sample_count:
# Need to break out of loop when repeat() is enabled for training w/ oversampling
# this results in extra samples per epoch but seems more desirable than dropping
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
break
if not self.is_training and self.dist_num_replicas > 1 and 0 < sample_count < target_sample_count:
# Validation batch padding only done for distributed training where results are reduced across nodes.
# For single process case, it won't matter if workers return different batch sizes.
# FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
assert target_sample_count - sample_count == 1 # should only be off by 1 or sharding is not optimal
yield img, sample['label'] # yield prev sample again
sample_count += 1
@property
def _num_workers(self):
return 1 if self.worker_info is None else self.worker_info.num_workers
@property
def _num_pipelines(self):
return self._num_workers * self.dist_num_replicas
def __len__(self):
# this is just an estimate and does not factor in extra samples added to pad batches based on
# complete worker & replica info (not available until init in dataloader).
return math.ceil(self.num_samples / self.dist_num_replicas)
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to samples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
if self.ds is None:
self._lazy_init()
names = []
for sample in self.ds:
if len(names) > self.num_samples:
break # safety for ds.repeat() case
if 'file_name' in sample:
name = sample['file_name']
elif 'filename' in sample:
name = sample['filename']
elif 'id' in sample:
name = sample['id']
else:
assert False, "No supported name field present"
names.append(name)
return names
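A sketch of wiring the TFDS parser into a PyTorch loader, following the constraints documented above (dataset name and sizes are examples; requires `tensorflow-datasets`, and in training mode `batch_size` must be provided so the wrapper can wrap/pad batches):

```python
import torch

from timm.data import IterableImageDataset, create_transform

ds = IterableImageDataset(
    '/data/tfds', parser='tfds/oxford_iiit_pet', split='train',
    is_training=True, batch_size=64,
    transform=create_transform(224, is_training=True))

# no sampler or shuffle flag here; the IterableDataset handles its own sharding/shuffling
loader = torch.utils.data.DataLoader(ds, batch_size=64, num_workers=2)
```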

@@ -16,6 +16,7 @@ from .regnet import *
from .res2net import *
from .resnest import *
from .resnet import *
from .resnetv2 import *
from .rexnet import *
from .selecsls import *
from .senet import *

@@ -6,8 +6,6 @@ from .layers import set_layer_config
def create_model(
        model_name,
        pretrained=False,
-        num_classes=1000,
-        in_chans=3,
        checkpoint_path='',
        scriptable=None,
        exportable=None,
@@ -18,8 +16,6 @@ def create_model(
    Args:
        model_name (str): name of model to instantiate
        pretrained (bool): load pretrained ImageNet-1k weights if true
-        num_classes (int): number of classes for final fully connected layer (default: 1000)
-        in_chans (int): number of input channels / colors (default: 3)
        checkpoint_path (str): path of checkpoint to load after model is initialized
        scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
        exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
@@ -30,7 +26,7 @@ def create_model(
        global_pool (str): global pool type (default: 'avg')
        **: other kwargs are model specific
    """
-    model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
+    model_args = dict(pretrained=pretrained)

    # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
    is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
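The removed arguments keep working, since the factory forwards extra kwargs to each model entrypoint, e.g.:

```python
import timm

# num_classes / in_chans now reach the model def via **kwargs rather than explicit factory args
model = timm.create_model('resnet50', pretrained=False, num_classes=10, in_chans=1)
```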

@@ -11,7 +11,11 @@ from typing import Callable

import torch
import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
+from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
+try:
+    from torch.hub import get_dir
+except ImportError:
+    from torch.hub import _get_torch_home as get_dir

from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame, Linear
@@ -88,15 +92,70 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
        raise FileNotFoundError()


def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
    r"""Loads a custom (read non .pth) weight file
    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
    a passed in custom load fn, or the `load_pretrained` model member fn.
    If the object is already present in `model_dir`, it's deserialized and returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
    Args:
        model: The instantiated model to load weights into
        cfg (dict): Default pretrained model cfg
        load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
            'load_pretrained' on the model will be called if it exists
        progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
        check_hash (bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file. Default: False
    """
    if cfg is None:
        cfg = getattr(model, 'default_cfg')
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL does not exist, using random initialization.")
        return
    url = cfg['url']

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if load_fn is not None:
        load_fn(model, cached_file)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(cached_file)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")


-def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
+def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
    if cfg is None:
        cfg = getattr(model, 'default_cfg')
    if cfg is None or 'url' not in cfg or not cfg['url']:
-        _logger.warning("Pretrained model URL is invalid, using random initialization.")
+        _logger.warning("Pretrained model URL does not exist, using random initialization.")
        return

-    state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
+    state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)
@@ -139,6 +198,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
    classifier_name = cfg['classifier']
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # FIXME this special case is problematic as number of pretrained weight sources increases
        # special case for imagenet trained models with extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
@@ -269,6 +329,7 @@ def build_model_with_cfg(
        feature_cfg: dict = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Callable = None,
        pretrained_custom_load: bool = False,
        **kwargs):
    pruned = kwargs.pop('pruned', False)
    features = False
@@ -289,6 +350,9 @@ def build_model_with_cfg(
    # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_custom_load:
            load_custom_pretrained(model)
        else:
            load_pretrained(
                model,
                num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
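A hedged sketch of the custom load path added above (the model, URL, and load function here are hypothetical stand-ins; the new BiT resnetv2 models presumably use this mechanism via `pretrained_custom_load=True` for their non-.pth `.npz` weights):

```python
import numpy as np
import torch.nn as nn

from timm.models.helpers import load_custom_pretrained


def load_npz_weights(model, checkpoint_path):
    # hypothetical load_fn: map arrays from a numpy .npz checkpoint onto the model
    weights = np.load(checkpoint_path)
    print(f'loaded {len(weights.files)} arrays')  # real code would copy them into model params


model = nn.Conv2d(3, 8, 3)  # stand-in for a real timm model
model.default_cfg = dict(url='https://example.com/my_weights.npz')  # hypothetical cfg
load_custom_pretrained(model, load_fn=load_npz_weights)  # downloads to the torch.hub cache dir
```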

@@ -7,7 +7,7 @@ from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
    set_layer_config
-from .conv2d_same import Conv2dSame
+from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import create_attn
@@ -20,8 +20,8 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
-from .norm_act import BatchNormAct2d
-from .padding import get_padding
+from .norm_act import BatchNormAct2d, GroupNormAct
+from .padding import get_padding, get_same_padding, pad_same
from .pool2d_same import AvgPool2dSame, create_pool2d
from .se import SEModule
from .selective_kernel import SelectiveKernelConv

@@ -9,31 +9,43 @@ from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .linear import Linear


-def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
-    flatten = not use_conv  # flatten when we use a Linear layer after pooling
+def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
+    flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
    if not pool_type:
        assert num_classes == 0 or use_conv,\
            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
-        flatten = False  # disable flattening if pooling is pass-through (no pooling)
-    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten)
+        flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
+    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
    num_pooled_features = num_features * global_pool.feat_mult()
+    return global_pool, num_pooled_features
+
+
+def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False):
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through (no classifier)
    elif use_conv:
-        fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
+        fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
    else:
        # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
-        fc = Linear(num_pooled_features, num_classes, bias=True)
+        fc = Linear(num_features, num_classes, bias=True)
+    return fc
+
+
+def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
+    global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
+    fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
    return global_pool, fc


class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

-    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
+    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
-        self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type)
+        self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
+        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+        self.flatten_after_fc = use_conv and pool_type

    def forward(self, x):
        x = self.global_pool(x)
@@ -68,8 +68,8 @@ class BatchNormAct2d(nn.BatchNorm2d):


class GroupNormAct(nn.GroupNorm):
    # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
-    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True,
+    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
        if isinstance(act_layer, str):
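With `num_channels` first, `GroupNormAct` now binds like other norm layers that take channel count as the leading argument; a small sketch:

```python
from functools import partial

from timm.models.layers import GroupNormAct

# fix the group count, leaving a norm_layer factory that takes channels first
norm_layer = partial(GroupNormAct, num_groups=32)
gn = norm_layer(64)  # 64 channels, 32 groups
```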

@@ -403,7 +403,7 @@ class ReductionCell1(nn.Module):

class NASNetALarge(nn.Module):
    """NASNetALarge (6 @ 4032) """

-    def __init__(self, num_classes=1000, in_chans=1, stem_size=96, channel_multiplier=2,
+    def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
                 num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
        super(NASNetALarge, self).__init__()
        self.num_classes = num_classes

@ -162,6 +162,12 @@ default_cfgs = {
'seresnet152d_320': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
'seresnet200d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'seresnet269d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
# Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
@ -216,6 +222,12 @@ default_cfgs = {
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet200d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'ecaresnet269d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
# Efficient Channel Attention ResNeXts
'ecaresnext26tn_32x4d': _cfg(
@ -1123,6 +1135,26 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs):
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
@register_model
def ecaresnet200d(pretrained=False, **kwargs):
"""Constructs a ResNet-200-D model with ECA.
"""
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet200d', pretrained, **model_args)
@register_model
def ecaresnet269d(pretrained=False, **kwargs):
"""Constructs a ResNet-269-D model with ECA.
"""
model_args = dict(
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet269d', pretrained, **model_args)
@register_model
def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
"""Constructs an ECA-ResNeXt-26-TN model.
@ -1198,6 +1230,26 @@ def seresnet152d(pretrained=False, **kwargs):
return _create_resnet('seresnet152d', pretrained, **model_args)
@register_model
def seresnet200d(pretrained=False, **kwargs):
"""Constructs a ResNet-200-D model with SE attn.
"""
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet200d', pretrained, **model_args)
@register_model
def seresnet269d(pretrained=False, **kwargs):
"""Constructs a ResNet-269-D model with SE attn.
"""
model_args = dict(
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet269d', pretrained, **model_args)
@register_model
def seresnet152d_320(pretrained=False, **kwargs):
model_args = dict(

@ -0,0 +1,593 @@
"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization.
A PyTorch implementation of ResNetV2 adapted from the Google Big-Transfer (BiT) source code
at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have
been included here as pretrained models from their original .NPZ checkpoints.
Additionally, supports non pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and
extra padding support to allow porting of official Hybrid ResNet pretrained weights from
https://github.com/google-research/vision_transformer
Thanks to the Google team for the above two repositories and associated papers:
* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
"""
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict # pylint: disable=g-importing-member
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 480, 480), 'pool_size': (7, 7),
'crop_pct': 1.0, 'interpolation': 'bilinear',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = {
# pretrained on imagenet21k, finetuned on imagenet1k
'resnetv2_50x1_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz'),
'resnetv2_50x3_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz'),
'resnetv2_101x1_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz'),
'resnetv2_101x3_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz'),
'resnetv2_152x2_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz'),
'resnetv2_152x4_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz'),
# trained on imagenet-21k
'resnetv2_50x1_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
num_classes=21843),
'resnetv2_50x3_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
num_classes=21843),
'resnetv2_101x1_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
num_classes=21843),
'resnetv2_101x3_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
num_classes=21843),
'resnetv2_152x2_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
num_classes=21843),
'resnetv2_152x4_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
num_classes=21843),
# trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now
# 'resnetv2_50x1_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'),
# 'resnetv2_50x3_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
# 'resnetv2_101x1_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x1.npz'),
# 'resnetv2_101x3_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
# 'resnetv2_152x2_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
# 'resnetv2_152x4_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
}
def make_div(v, divisor=8):
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
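# Illustrative sketch (not part of this commit): make_div rounds a scaled channel
# count to the nearest multiple of `divisor`, bumping up one step whenever plain
# rounding would lose more than 10% of the requested width.
assert make_div(64 * 0.25) == 16  # already a multiple of 8, unchanged
assert make_div(100) == 104       # rounds to the nearest multiple of 8
assert make_div(10) == 16         # 8 < 0.9 * 10, so bumped up by one divisor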
class StdConv2d(nn.Conv2d):
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=bias, groups=groups)
self.eps = eps
def forward(self, x):
w = self.weight
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w = (w - m) / (torch.sqrt(v) + self.eps)
x = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
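# Illustrative sketch (not part of this commit): StdConv2d standardizes each output
# filter to zero mean / unit variance on the fly (Weight Standardization), which is
# part of how BiT pairs GroupNorm with transfer learning instead of BatchNorm.
_conv = StdConv2d(3, 8, kernel_size=3)
_v, _m = torch.var_mean(_conv.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
_w_std = (_conv.weight - _m) / (torch.sqrt(_v) + _conv.eps)
print(_w_std.mean().item(), _w_std.std().item())  # ~0.0, ~1.0
print(_conv(torch.randn(1, 3, 32, 32)).shape)     # torch.Size([1, 8, 32, 32])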
class StdConv2dSame(nn.Conv2d):
"""StdConv2d w/ TF compatible SAME padding. Used for ViT Hybrid model.
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=bias, groups=groups)
self.eps = eps
def forward(self, x):
w = self.weight
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w = (w - m) / (torch.sqrt(v) + self.eps)
x = conv2d_same(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
def tf2th(conv_weights):
"""Possibly convert HWIO to OIHW."""
if conv_weights.ndim == 4:
conv_weights = conv_weights.transpose([3, 2, 0, 1])
return torch.from_numpy(conv_weights)
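# Illustrative sketch (not part of this commit): a TF kernel stored HWIO, e.g.
# (3, 3, 3, 64), comes back as an OIHW torch tensor (64, 3, 3, 3); 1d arrays
# (biases, GroupNorm gamma/beta) pass through unchanged.
import numpy as np
print(tuple(tf2th(np.zeros((3, 3, 3, 64), dtype=np.float32)).shape))  # (64, 3, 3, 3)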
class PreActBottleneck(nn.Module):
"""Pre-activation (v2) bottleneck block.
Follows the implementation of "Identity Mappings in Deep Residual Networks":
https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
Except it puts the stride on 3x3 conv when available.
"""
def __init__(
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
super().__init__()
first_dilation = first_dilation or dilation
conv_layer = conv_layer or StdConv2d
norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
out_chs = out_chs or in_chs
mid_chs = make_div(out_chs * bottle_ratio)
if proj_layer is not None:
self.downsample = proj_layer(
in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
conv_layer=conv_layer, norm_layer=norm_layer)
else:
self.downsample = None
self.norm1 = norm_layer(in_chs)
self.conv1 = conv_layer(in_chs, mid_chs, 1)
self.norm2 = norm_layer(mid_chs)
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
self.norm3 = norm_layer(mid_chs)
self.conv3 = conv_layer(mid_chs, out_chs, 1)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
def forward(self, x):
x_preact = self.norm1(x)
# shortcut branch
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x_preact)
# residual branch
x = self.conv1(x_preact)
x = self.conv2(self.norm2(x))
x = self.conv3(self.norm3(x))
x = self.drop_path(x)
return x + shortcut
class Bottleneck(nn.Module):
"""Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
"""
def __init__(
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
super().__init__()
first_dilation = first_dilation or dilation
act_layer = act_layer or nn.ReLU
conv_layer = conv_layer or StdConv2d
norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
out_chs = out_chs or in_chs
mid_chs = make_div(out_chs * bottle_ratio)
if proj_layer is not None:
self.downsample = proj_layer(
in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
conv_layer=conv_layer, norm_layer=norm_layer)
else:
self.downsample = None
self.conv1 = conv_layer(in_chs, mid_chs, 1)
self.norm1 = norm_layer(mid_chs)
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
self.norm2 = norm_layer(mid_chs)
self.conv3 = conv_layer(mid_chs, out_chs, 1)
self.norm3 = norm_layer(out_chs, apply_act=False)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.act3 = act_layer(inplace=True)
def forward(self, x):
# shortcut branch
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
# residual
x = self.conv1(x)
x = self.norm1(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.conv3(x)
x = self.norm3(x)
x = self.drop_path(x)
x = self.act3(x + shortcut)
return x
class DownsampleConv(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
conv_layer=None, norm_layer=None):
super(DownsampleConv, self).__init__()
self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
def forward(self, x):
return self.norm(self.conv(x))
class DownsampleAvg(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
preact=True, conv_layer=None, norm_layer=None):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
super(DownsampleAvg, self).__init__()
avg_stride = stride if dilation == 1 else 1
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
else:
self.pool = nn.Identity()
self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
def forward(self, x):
return self.norm(self.conv(self.pool(x)))
class ResNetStage(nn.Module):
"""ResNet Stage."""
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
super(ResNetStage, self).__init__()
first_dilation = 1 if dilation in (1, 2) else 2
layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
proj_layer = DownsampleAvg if avg_down else DownsampleConv
prev_chs = in_chs
self.blocks = nn.Sequential()
for block_idx in range(depth):
drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
stride = stride if block_idx == 0 else 1
self.blocks.add_module(str(block_idx), block_fn(
prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
**layer_kwargs, **block_kwargs))
prev_chs = out_chs
first_dilation = dilation
proj_layer = None
def forward(self, x):
x = self.blocks(x)
return x
def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
stem = OrderedDict()
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
# NOTE conv padding mode can be changed by overriding the conv_layer def
if 'deep' in stem_type:
# A 3 deep 3x3 conv stack as in ResNet V1D models
mid_chs = out_chs // 2
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
else:
# The usual 7x7 stem conv
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
if not preact:
stem['norm'] = norm_layer(out_chs)
if 'fixed' in stem_type:
# 'fixed' SAME padding approximation that is used in BiT models
stem['pad'] = nn.ConstantPad2d(1, 0.)
stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
elif 'same' in stem_type:
# full, input size based 'SAME' padding, used in ViT Hybrid model
stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same')
else:
# the usual PyTorch symmetric padding
stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
return nn.Sequential(stem)
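# Illustrative sketch (not part of this commit): the BiT 'fixed' stem is a 7x7
# standardized conv followed by a fixed constant pad + max pool that approximates
# TF SAME padding at stride 2.
_stem = create_stem(
    3, 64, stem_type='fixed', preact=True,
    conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32))
print(_stem(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 56, 56])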
class ResNetV2(nn.Module):
"""Implementation of Pre-activation (v2) ResNet mode.
"""
def __init__(self, layers, channels=(256, 512, 1024, 2048),
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
wf = width_factor
self.feature_info = []
stem_chs = make_div(stem_chs * wf)
self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
# NOTE: no reduction-2 feature if preact
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))
prev_chs = stem_chs
curr_stride = 4
dilation = 1
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
block_fn = PreActBottleneck if preact else Bottleneck
self.stages = nn.Sequential()
for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
out_chs = make_div(c * wf)
stride = 1 if stage_idx == 0 else 2
if curr_stride >= output_stride:
dilation *= stride
stride = 1
stage = ResNetStage(
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
prev_chs = out_chs
curr_stride *= stride
feat_name = f'stages.{stage_idx}'
if preact:
feat_name = f'stages.{stage_idx + 1}.blocks.0.norm1' if (stage_idx + 1) != len(channels) else 'norm'
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=feat_name)]
self.stages.add_module(str(stage_idx), stage)
self.num_features = prev_chs
self.norm = norm_layer(self.num_features) if preact else nn.Identity()
self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
for n, m in self.named_modules():
if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
nn.init.normal_(m.weight, mean=0.0, std=0.01)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
if not self.head.global_pool.is_identity():
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
return x
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
import numpy as np
weights = np.load(checkpoint_path)
with torch.no_grad():
stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
if self.stem.conv.weight.shape[1] == 1:
self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
# FIXME handle > 3 in_chans?
else:
self.stem.conv.weight.copy_(stem_conv_w)
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
for i, (sname, stage) in enumerate(self.stages.named_children()):
for j, (bname, block) in enumerate(stage.blocks.named_children()):
convname = 'standardized_conv2d'
block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel']))
block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel']))
block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma']))
block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma']))
block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma']))
block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta']))
block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta']))
if block.downsample is not None:
w = weights[f'{block_prefix}a/proj/{convname}/kernel']
block.downsample.conv.weight.copy_(tf2th(w))
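# Minimal usage sketch (random init, not part of this commit): a width-factor-1
# ResNet-50 V2 with the BiT 'fixed' stem; the 1x1 conv classifier head output is
# flattened to (N, num_classes) in forward().
_model = ResNetV2(layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', num_classes=10)
print(_model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 10])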
def _create_resnetv2(variant, pretrained=False, **kwargs):
# FIXME feature map extraction is not setup properly for pre-activation mode right now
preact = kwargs.get('preact', True)
feature_cfg = dict(flatten_sequential=True)
if preact:
feature_cfg['feature_cls'] = 'hook'
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for preact
return build_model_with_cfg(
ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
feature_cfg=feature_cfg, **kwargs)
@register_model
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50x1_bitm', pretrained=pretrained,
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
@register_model
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50x3_bitm', pretrained=pretrained,
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
@register_model
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101x1_bitm', pretrained=pretrained,
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
@register_model
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101x3_bitm', pretrained=pretrained,
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
@register_model
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152x2_bitm', pretrained=pretrained,
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
@register_model
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152x4_bitm', pretrained=pretrained,
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
@register_model
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
@register_model
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
@register_model
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
@register_model
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
@register_model
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
@register_model
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
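# Usage sketch (illustrative, assumes timm is installed): any of the defs above
# resolve through the timm registry; pretrained=True pulls the .npz weights via the
# custom load_pretrained path rather than a torch state_dict.
import timm
model_21k = timm.create_model('resnetv2_50x1_bitm_in21k', pretrained=False)
print(model_21k.default_cfg['num_classes'])  # 21843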
# NOTE the 'S' (ImageNet-1k trained) versions of the model weights aren't as interesting as the original 21k or the 21k->1k transfer 'M' weights.
# @register_model
# def resnetv2_50x1_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_50x1_bits', pretrained=pretrained,
# layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_50x3_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_50x3_bits', pretrained=pretrained,
# layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_101x1_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_101x1_bits', pretrained=pretrained,
# layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_101x3_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_101x3_bits', pretrained=pretrained,
# layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_152x2_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_152x2_bits', pretrained=pretrained,
# layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_152x4_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_152x4_bits', pretrained=pretrained,
# layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
#

@ -5,12 +5,6 @@ A PyTorch implement of Vision Transformers as described in
The official jax code is released and available at https://github.com/google-research/vision_transformer
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
@ -18,18 +12,29 @@ for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import logging
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import DropPath, to_2tuple, trunc_normal_
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, StdConv2dSame
from .registry import register_model
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
@ -43,14 +48,19 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# patch models (my experiments)
'vit_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
),
# patch models (weights ported from official Google JAX impl)
'vit_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
'vit_base_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_base_patch16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
@ -60,19 +70,66 @@ default_cfgs = {
'vit_large_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
'vit_large_patch32_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
# patch models, imagenet21k (weights ported from official Google JAX impl)
'vit_base_patch16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_base_patch32_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch32_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_huge_patch14_224_in21k': _cfg(
url='', # FIXME I have weights for this but > 2GB limit for github release binaries
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# hybrid models (weights ported from official Google JAX impl)
'vit_base_resnet50_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
'vit_base_resnet50_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'),
# hybrid models (my experiments)
'vit_small_resnet26d_224': _cfg(),
'vit_small_resnet50d_s3_224': _cfg(),
'vit_base_resnet26d_224': _cfg(),
'vit_base_resnet50d_224': _cfg(),
# deit models (FB weights)
'vit_deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
'vit_deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
'vit_deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
'vit_deit_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
'vit_deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
'vit_deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
'vit_deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0),
}
@ -184,32 +241,61 @@ class HybridEmbed(nn.Module):
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1]  # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1]  # last feature if backbone outputs list/tuple of features
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
https://arxiv.org/abs/2010.11929
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
norm_layer (nn.Module): normalization layer
"""
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
@ -231,12 +317,18 @@ class VisionTransformer(nn.Module):
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size:
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
# Classifier head
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
@ -274,8 +366,9 @@ class VisionTransformer(nn.Module):
for blk in self.blocks:
x = blk(x)
x = self.norm(x)[:, 0]
x = self.pre_logits(x)
return x
def forward(self, x):
x = self.forward_features(x)
@ -283,146 +376,412 @@ class VisionTransformer(nn.Module):
return x
class DistilledVisionTransformer(VisionTransformer):
""" Vision Transformer with distillation token.
Paper: `Training data-efficient image transformers & distillation through attention` -
https://arxiv.org/abs/2012.12877
This impl of distilled ViT is taken from https://github.com/facebookresearch/deit
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
num_patches = self.patch_embed.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
trunc_normal_(self.dist_token, std=.02)
trunc_normal_(self.pos_embed, std=.02)
self.head_dist.apply(self._init_weights)
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x[:, 0], x[:, 1]
def forward(self, x):
x, x_dist = self.forward_features(x)
x = self.head(x)
x_dist = self.head_dist(x_dist)
if self.training:
return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2
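# Behavior sketch (illustrative, not part of this commit): in train mode the
# distilled model returns both heads for the DeiT distillation loss; in eval mode
# it returns their average. Tiny config used here purely to keep the example light.
_deit = DistilledVisionTransformer(patch_size=16, embed_dim=192, depth=1, num_heads=3, num_classes=10)
_x = torch.randn(1, 3, 224, 224)
_cls_out, _dist_out = _deit.train()(_x)  # two (1, 10) logits tensors
_avg_out = _deit.eval()(_x)              # one (1, 10) tensor, (cls + dist) / 2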
def resize_pos_embed(posemb, posemb_new):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
ntok_new = posemb_new.shape[1]
if True:
posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
ntok_new -= 1
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
gs_new = int(math.sqrt(ntok_new))
_logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
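# Illustrative sketch (not part of this commit): grow a 224px/16 grid (14x14 plus
# the class token) to a 384px/16 grid (24x24 plus the class token); only the grid
# portion is bilinearly interpolated, the token embedding is carried over as-is.
_old = torch.randn(1, 14 * 14 + 1, 768)
_new = torch.zeros(1, 24 * 24 + 1, 768)
print(resize_pos_embed(_old, _new).shape)  # torch.Size([1, 577, 768])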
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv""" """ convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {} out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification
O, I, H, W = model.patch_embed.proj.weight.shape
v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(v, model.pos_embed)
out_dict[k] = v
return out_dict
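# Illustrative sketch (not part of this commit): an old-style 2d patch projection
# weight is reshaped to the conv kernel the current model expects.
_vit = VisionTransformer(patch_size=16, embed_dim=768, depth=1, num_heads=12)
_sd = checkpoint_filter_fn({'patch_embed.proj.weight': torch.randn(768, 3 * 16 * 16)}, _vit)
print(_sd['patch_embed.proj.weight'].shape)  # torch.Size([768, 3, 16, 16])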
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
default_cfg = default_cfgs[variant]
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-1]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes:
# Remove representation layer if fine-tuning. This may not always be the desired action,
# but it feels better than doing nothing by default for fine-tuning. Perhaps a better interface?
_logger.warning("Removing representation layer for fine-tuning.")
repr_size = None
model_cls = DistilledVisionTransformer if distilled else VisionTransformer
model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(
model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
filter_fn=partial(checkpoint_filter_fn, model=model))
return model
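# Behavior sketch (illustrative, not part of this commit): requesting a num_classes
# different from the default cfg (i.e. fine-tuning) drops the in21k representation
# layer, so pre_logits becomes an Identity and the head attaches to embed_dim.
_ft = _create_vision_transformer(
    'vit_base_patch16_224_in21k', pretrained=False, num_classes=100,
    patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768)
print(type(_ft.pre_logits).__name__)  # Identity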
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
""" My custom 'small' ViT model. Depth=8, heads=8, mlp_ratio=3."""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
if pretrained:
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
model_kwargs.setdefault('qk_scale', 768 ** -0.5)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch32_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch32_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model return model
@register_model
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: converted weights not currently available, too large for github release hosting.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
# create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
backbone = ResNetV2(
layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
preact=False, stem_type='same', conv_layer=StdConv2dSame)
model_kwargs = dict(
embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
# create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
backbone = ResNetV2(
layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
preact=False, stem_type='same', conv_layer=StdConv2dSame)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
return model
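# Usage sketch (illustrative, not part of this commit): the R50+ViT-B/16 hybrid
# tokenizes a 384x384 image through the 3-stage ResNetV2 backbone (total stride
# 16), giving 24*24 = 576 patch tokens ahead of the transformer blocks.
_hybrid = vit_base_resnet50_384(pretrained=False)
print(_hybrid.patch_embed.num_patches)             # 576
print(_hybrid(torch.randn(1, 3, 384, 384)).shape)  # torch.Size([1, 1000])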
@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
"""
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
"""
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_base_resnet26d_224(pretrained=False, **kwargs): def vit_base_resnet26d_224(pretrained=False, **kwargs):
pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) """
model = VisionTransformer( backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model.default_cfg = default_cfgs['vit_base_resnet26d_224'] model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_base_resnet50d_224(pretrained=False, **kwargs): def vit_base_resnet50d_224(pretrained=False, **kwargs):
pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) """
model = VisionTransformer( backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model.default_cfg = default_cfgs['vit_base_resnet50d_224'] model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
return model
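# A minimal sketch of what features_only/out_indices supplies to the hybrid
# wrappers above: a single late-stage feature map that the ViT patch-embeds in
# place of raw pixels. The printed shape is an expectation, not a guarantee.
import torch
from timm.models.resnet import resnet26d

cnn = resnet26d(pretrained=False, features_only=True, out_indices=[4])
fmap = cnn(torch.randn(1, 3, 224, 224))[0]
print(fmap.shape)  # stride-32 map, expected around [1, 2048, 7, 7] for 224x224 input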
@register_model
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_base_patch16_384(pretrained=False, **kwargs):
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer(
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer(
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
    return model
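# A minimal sketch of the distilled variants, assuming timm follows the DeiT
# paper here: distilled=True adds a distillation token beside the class token,
# and eval-mode forward is expected to return a single fused logits tensor.
import timm
import torch

deit = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=False)
deit.eval()
with torch.no_grad():
    out = deit(torch.randn(1, 3, 224, 224))  # fused class/distill head output in eval mode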

@@ -1 +1 @@
-__version__ = '0.3.4'
+__version__ = '0.4.0'

@@ -28,7 +28,7 @@
 import torch.nn as nn
 import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
-from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
@@ -64,8 +64,14 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 
 # Dataset / Model parameters
-parser.add_argument('data', metavar='DIR',
+parser.add_argument('data_dir', metavar='DIR',
                     help='path to dataset')
+parser.add_argument('--dataset', '-d', metavar='NAME', default='',
+                    help='dataset type (default: ImageFolder/ImageTar if empty)')
+parser.add_argument('--train-split', metavar='NAME', default='train',
+                    help='dataset train split (default: train)')
+parser.add_argument('--val-split', metavar='NAME', default='validation',
+                    help='dataset validation split (default: validation)')
 parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                     help='Name of model to train (default: "countception"')
 parser.add_argument('--pretrained', action='store_true', default=False,
@ -76,8 +82,8 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)') help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False, parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model') help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N', parser.add_argument('--num-classes', type=int, default=None, metavar='N',
help='number of label classes (default: 1000)') help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL', parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N', parser.add_argument('--img-size', type=int, default=None, metavar='N',
@@ -331,6 +337,9 @@ def main():
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint)
+    if args.num_classes is None:
+        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
+        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
 
     if args.local_rank == 0:
         _logger.info('Model %s created, param count: %d' %
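# A minimal sketch of the num_classes fallback added above: when --num-classes
# is left unset, the value is taken from the instantiated model.
import timm

m = timm.create_model('resnet50')
num_classes = getattr(m, 'num_classes', None)
assert num_classes is not None, 'model must expose num_classes'
print(num_classes)  # 1000 for the stock ImageNet-1k head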
@@ -437,19 +446,10 @@ def main():
     _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
     # create the train and eval datasets
-    train_dir = os.path.join(args.data, 'train')
-    if not os.path.exists(train_dir):
-        _logger.error('Training folder does not exist at: {}'.format(train_dir))
-        exit(1)
-    dataset_train = Dataset(train_dir)
-
-    eval_dir = os.path.join(args.data, 'val')
-    if not os.path.isdir(eval_dir):
-        eval_dir = os.path.join(args.data, 'validation')
-        if not os.path.isdir(eval_dir):
-            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
-            exit(1)
-    dataset_eval = Dataset(eval_dir)
+    dataset_train = create_dataset(
+        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
+    dataset_eval = create_dataset(
+        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
 
     # setup mixup / cutmix
     collate_fn = None
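# A minimal sketch mirroring the refactored calls above; the root path is a
# placeholder, and an empty dataset name is assumed to select the default
# folder/tar parser.
from timm.data import create_dataset

dataset_train = create_dataset('', root='/data/imagenet', split='train', is_training=True, batch_size=256)
dataset_eval = create_dataset('', root='/data/imagenet', split='validation', is_training=False, batch_size=256)
print(len(dataset_train), len(dataset_eval))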
@@ -553,10 +553,10 @@ def main():
     try:
         for epoch in range(start_epoch, num_epochs):
-            if args.distributed:
+            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                 loader_train.sampler.set_epoch(epoch)
 
-            train_metrics = train_epoch(
+            train_metrics = train_one_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
@@ -594,7 +594,7 @@ def main():
     _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
 
 
-def train_epoch(
+def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
         lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
         loss_scaler=None, model_ema=None, mixup_fn=None):

@@ -20,7 +20,7 @@ from collections import OrderedDict
 from contextlib import suppress
 
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
+from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
 
 has_apex = False
@@ -44,7 +44,11 @@ _logger = logging.getLogger('validate')
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
+parser.add_argument('--dataset', '-d', metavar='NAME', default='',
+                    help='dataset type (default: ImageFolder/ImageTar if empty)')
+parser.add_argument('--split', metavar='NAME', default='validation',
+                    help='dataset split (default: validation)')
-parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
+parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
@@ -62,7 +66,7 @@ parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD'
                     help='Override std deviation of of dataset')
 parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
-parser.add_argument('--num-classes', type=int, default=1000,
+parser.add_argument('--num-classes', type=int, default=None,
                     help='Number classes in dataset')
 parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                     help='path to class to idx mapping file (default: "")')
@@ -133,6 +137,9 @@ def validate(args):
         in_chans=3,
         global_pool=args.gp,
         scriptable=args.torchscript)
+    if args.num_classes is None:
+        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
+        args.num_classes = model.num_classes
 
     if args.checkpoint:
         load_checkpoint(model, args.checkpoint, args.use_ema)
@@ -159,10 +166,9 @@ def validate(args):
     criterion = nn.CrossEntropyLoss().cuda()
 
-    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
-        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
-    else:
-        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
+    dataset = create_dataset(
+        root=args.data, name=args.dataset, split=args.split,
+        load_bytes=args.tf_preprocessing, class_map=args.class_map)
 
     if args.valid_labels:
         with open(args.valid_labels, 'r') as f:
