Merge remote-tracking branch 'origin/master' into bits_and_tpu

pull/1414/head
Ross Wightman 2 years ago
commit 1186fc9c73

@ -13,16 +13,50 @@
## Sponsors
A big thank you to my [GitHub Sponsors](https://github.com/sponsors/rwightman) for their support!
In addition to the sponsors at the link above, I've received hardware and/or cloud resources from
Thanks to the following for hardware support:
* TPU Research Cloud (TRC) (https://sites.research.google/trc/about/)
* Nvidia (https://www.nvidia.com/en-us/)
* TFRC (https://www.tensorflow.org/tfrc)
I'm fortunate to be able to dedicate significant time and money of my own to supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to keep up with the costs of cloud services, hardware, and electricity.
And a big thanks to all GitHub sponsors who helped with some of my costs before I joined Hugging Face.
## What's New
### July 8, 2022
More models, more fixes
* Official research models (w/ weights) added:
  * EdgeNeXt from (https://github.com/mmaaz60/EdgeNeXt)
  * MobileViT-V2 from (https://github.com/apple/ml-cvnets)
  * DeiT III (Revenge of the ViT) from (https://github.com/facebookresearch/deit)
* My own models:
  * Small `ResNet` defs added by request, with single-block repeats for both basic and bottleneck variants (resnet10 and resnet14)
  * `CspNet` refactored with dataclass config, simplified CrossStage3 (`cs3`) option. These are closer to YOLO-v5+ backbone defs.
  * More relative position vit fiddling. Two `srelpos` (shared relative position) models trained, and a medium w/ class token.
  * Add an alternate downsample mode to EdgeNeXt and train a `small` model. Better than original small, but not their new USI trained weights.
* My own model weight results (all ImageNet-1k training)
  * `resnet10t` - 66.5 @ 176, 68.3 @ 224
  * `resnet14t` - 71.3 @ 176, 72.3 @ 224
  * `resnetaa50` - 80.6 @ 224, 81.6 @ 288
  * `darknet53` - 80.0 @ 256, 80.5 @ 288
  * `cs3darknet_m` - 77.0 @ 256, 77.6 @ 288
  * `cs3darknet_focus_m` - 76.7 @ 256, 77.3 @ 288
  * `cs3darknet_l` - 80.4 @ 256, 80.9 @ 288
  * `cs3darknet_focus_l` - 80.3 @ 256, 80.9 @ 288
  * `vit_srelpos_small_patch16_224` - 81.1 @ 224, 82.1 @ 320
  * `vit_srelpos_medium_patch16_224` - 82.3 @ 224, 83.1 @ 320
  * `vit_relpos_small_patch16_cls_224` - 82.6 @ 224, 83.6 @ 320
  * `edgenext_small_rw` - 79.6 @ 224, 80.4 @ 320
  * `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to the TRC program! Rest trained on overheating GPUs.
* Hugging Face Hub support fixes verified, demo notebook TBA
* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103); a usage sketch follows this list.
* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
  * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
  * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
* Numerous bug fixes
* Currently testing for imminent PyPI 0.6.x release
* LeViT pretraining of larger models still a WIP; they don't train well / easily without distillation. Time to add distill support (finally)?
* ImageNet-22k weight training + finetune ongoing, work on multi-weight support (slowly) chugging along (there are a LOT of weights, sigh) ...
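
Usage sketch for the image-extension helpers mentioned above (names as exported from `timm.data` in this PR; defaults are `.png`, `.jpg`, `.jpeg`):

```python
from timm.data import get_img_extensions, add_img_extensions, del_img_extensions, set_img_extensions

print(get_img_extensions())             # ('.png', '.jpg', '.jpeg')
add_img_extensions('.webp')             # also scan .webp files
del_img_extensions('.png')              # stop scanning .png
set_img_extensions(['.jpg', '.jpeg'])   # or replace the whole list
print(get_img_extensions(as_set=True))
```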
### May 13, 2022
* Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript.
* Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects.
@ -349,6 +383,7 @@ A full version of the list below with source links can be found in the [document
* DenseNet - https://arxiv.org/abs/1608.06993
* DLA - https://arxiv.org/abs/1707.06484
* DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
* EdgeNeXt - https://arxiv.org/abs/2206.10589
* EfficientNet (MBConvNet Family)
* EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252
* EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665

@ -6,24 +6,23 @@ An inference and train step benchmark script for timm models.
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import csv
import json
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
import time
from collections import OrderedDict
from contextlib import suppress
from functools import partial
import torch
import torch.nn as nn
import torch.nn.parallel
from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.data import resolve_data_config
from timm.utils import setup_default_logging, set_jit_fuser
has_apex = False
try:
from apex import amp
@ -51,6 +50,12 @@ except ImportError as e:
FlopCountAnalysis = None
has_fvcore_profiling = False
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@ -65,6 +70,8 @@ parser.add_argument('--bench', default='both', type=str,
help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
parser.add_argument('--detail', action='store_true', default=False,
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--no-retry', action='store_true', default=False,
help='Do not decay batch size and retry on error.')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--num-warm-iter', default=10, type=int,
@ -95,10 +102,13 @@ parser.add_argument('--amp', action='store_true', default=False,
help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--precision', default='float32', type=str,
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
# train optimizer parameters
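
For reference, a rough standalone sketch of what the two new options do inside the runner, mirroring the functorch import above and the `aot_autograd` branch in `BenchmarkRunner` further down (a sketch only; assumes functorch and CUDA are available):

```python
# Hand-rolled equivalent of `--fuser nvfuser --aot-autograd`
import torch
from timm.models import create_model
from timm.utils import set_jit_fuser
from functorch.compile import memory_efficient_fusion

set_jit_fuser('nvfuser')                    # --fuser nvfuser
model = create_model('resnet50', pretrained=False).cuda().eval()
model = memory_efficient_fusion(model)      # --aot-autograd
with torch.no_grad():
    out = model(torch.randn(8, 3, 224, 224, device='cuda'))
print(out.shape)
```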
@ -160,10 +170,9 @@ def resolve_precision(precision: str):
def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
macs, _ = get_model_profile(
_, macs, _ = get_model_profile(
model=model,
input_res=(batch_size,) + input_size, # input shape or input to the input_constructor
input_constructor=None, # if specified, a constructor taking input_res is used as input to the model
input_shape=(batch_size,) + input_size, # input shape/resolution
print_profile=detailed, # prints the model graph with the measured profile attached to each module
detailed=detailed, # print the detailed profile
warm_up=10, # the number of warm-ups before measuring the time of each module
@ -188,8 +197,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
class BenchmarkRunner:
def __init__(
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
self,
model_name,
detail=False,
device='cuda',
torchscript=False,
aot_autograd=False,
precision='float32',
fuser='',
num_warm_iter=10,
num_bench_iter=50,
use_train_size=False,
**kwargs
):
self.model_name = model_name
self.detail = detail
self.device = device
@ -216,15 +236,19 @@ class BenchmarkRunner:
self.num_classes = self.model.num_classes
self.param_count = count_params(self.model)
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
self.scripted = False
if torchscript:
self.model = torch.jit.script(self.model)
self.scripted = True
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
self.input_size = data_config['input_size']
self.batch_size = kwargs.pop('batch_size', 256)
if aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
self.model = memory_efficient_fusion(self.model)
self.example_inputs = None
self.num_warm_iter = num_warm_iter
self.num_bench_iter = num_bench_iter
@ -243,7 +267,13 @@ class BenchmarkRunner:
class InferenceBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
def __init__(
self,
model_name,
device='cuda',
torchscript=False,
**kwargs
):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.eval()
@ -312,7 +342,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
class TrainBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
def __init__(
self,
model_name,
device='cuda',
torchscript=False,
**kwargs
):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.train()
@ -479,7 +515,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
return max(0, int(out_batch_size))
def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
batch_size = initial_batch_size
results = dict()
error_str = 'Unknown'
@ -494,8 +530,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
if 'channels_last' in error_str:
_logger.error(f'{model_name} not supported in channels_last, skipping.')
break
_logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
_logger.error(f'"{error_str}" while running benchmark.')
if no_batch_size_retry:
break
batch_size = decay_batch_exp(batch_size)
_logger.warning(f'Reducing batch size to {batch_size} for retry.')
results['error'] = error_str
return results
@ -537,7 +576,13 @@ def benchmark(args):
model_results = OrderedDict(model=model)
for prefix, bench_fn in zip(prefixes, bench_fns):
run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
run_results = _try_run(
model,
bench_fn,
bench_kwargs=bench_kwargs,
initial_batch_size=batch_size,
no_batch_size_retry=args.no_retry,
)
if prefix and 'error' not in run_results:
run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
model_results.update(run_results)

@ -5,6 +5,7 @@ import logging
import torch
import torch.nn as nn
from timm.models.layers import convert_sync_batchnorm
from timm.optim import create_optimizer_v2
from timm.utils import ModelEmaV2
@ -65,7 +66,7 @@ def setup_model_and_optimizer(
dev_env.to_device(model)
if use_syncbn and dev_env.distributed:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = convert_sync_batchnorm(model)
if dev_env.primary:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '

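A minimal sketch of the call pattern for the swapped-in helper (an assumption here is that it converts plain `BatchNorm2d` as well as timm's norm+act variants; the call signature matches the torch version it replaces):

```python
import torch.nn as nn
from timm.models.layers import convert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
model = convert_sync_batchnorm(model)  # BN layers become SyncBatchNorm variants
print(model)
```
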
@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader_v2
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser
from .parsers import create_parser,\
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet
from .transforms import RandomResizedCropAndInterpolation, ToTensor, ToNumpy
from .transforms_factory import create_transform_v2, create_transform

@ -36,27 +36,46 @@ _HPARAMS_DEFAULT = dict(
img_mean=_FILL,
)
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
# favor of the Image.Resampling enum. The top-level resampling attributes will be
# removed in Pillow 10.
if hasattr(Image, "Resampling"):
_RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
_DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
_pil_interpolation_to_str = {
Image.Resampling.NEAREST: 'nearest',
Image.Resampling.BILINEAR: 'bilinear',
Image.Resampling.BICUBIC: 'bicubic',
Image.Resampling.BOX: 'box',
Image.Resampling.HAMMING: 'hamming',
Image.Resampling.LANCZOS: 'lanczos',
}
else:
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
_DEFAULT_INTERPOLATION = Image.BICUBIC
_pil_interpolation_to_str = {
Image.NEAREST: 'nearest',
Image.BILINEAR: 'bilinear',
Image.BICUBIC: 'bicubic',
Image.BOX: 'box',
Image.HAMMING: 'hamming',
Image.LANCZOS: 'lanczos',
}
_str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()}
def _pil_interp(method):
def _convert(m):
if method == 'bicubic':
return Image.BICUBIC
elif method == 'lanczos':
return Image.LANCZOS
elif method == 'hamming':
return Image.HAMMING
else:
return Image.BILINEAR
if isinstance(method, (list, tuple)):
return [_convert(m) if isinstance(m, str) else m for m in method]
return [_str_to_pil_interpolation[m] if isinstance(m, str) else m for m in method]
else:
return _convert(method) if isinstance(method, str) else method
return _str_to_pil_interpolation[method] if isinstance(method, str) else method
def _interpolation(kwargs):
interpolation = kwargs.pop('resample', Image.BILINEAR)
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
else:

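For downstream code that only needs one or two filters, a compact version-safe lookup is also possible (a sketch, not the explicit `hasattr` branching used above):

```python
from PIL import Image

# Image.Resampling exists on Pillow >= 9.1; older releases expose the filters at top level.
_Resampling = getattr(Image, "Resampling", Image)
print(_Resampling.BICUBIC, _Resampling.BILINEAR)
```
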
@ -107,11 +107,15 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v
new_config['std'] = default_cfg['std']
# resolve default crop percentage
new_config['crop_pct'] = DEFAULT_CROP_PCT
crop_pct = DEFAULT_CROP_PCT
if 'crop_pct' in args and args['crop_pct'] is not None:
new_config['crop_pct'] = args['crop_pct']
elif 'crop_pct' in default_cfg:
new_config['crop_pct'] = default_cfg['crop_pct']
crop_pct = args['crop_pct']
else:
if use_test_size and 'test_crop_pct' in default_cfg:
crop_pct = default_cfg['test_crop_pct']
elif 'crop_pct' in default_cfg:
crop_pct = default_cfg['crop_pct']
new_config['crop_pct'] = crop_pct
if getattr(args, 'mixup', 0) > 0 \
or getattr(args, 'cutmix', 0) > 0. \

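With this change, `use_test_size=True` should resolve the test-time crop for models that define `test_crop_pct` in their default cfg (e.g. the EdgeNeXt cfgs added in this PR); a quick sketch:

```python
from timm.models import create_model
from timm.data import resolve_data_config

model = create_model('edgenext_xx_small', pretrained=False)
cfg = resolve_data_config({}, model=model, use_test_size=True)
print(cfg['input_size'], cfg['crop_pct'])  # expecting (3, 288, 288) and 1.0 per the cfg
```
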
@ -26,8 +26,8 @@ _TORCH_BASIC_DS = dict(
kmnist=KMNIST,
fashion_mnist=FashionMNIST,
)
_TRAIN_SYNONYM = {'train', 'training'}
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
_TRAIN_SYNONYM = dict(train=None, training=None)
_EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None)
def _search_split(root, split):

@ -1 +1,2 @@
from .parser_factory import create_parser
from .img_extensions import *

@ -1 +0,0 @@
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

@ -0,0 +1,50 @@
from copy import deepcopy
__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use
_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync
def _set_extensions(extensions):
global IMG_EXTENSIONS
global _IMG_EXTENSIONS_SET
dedupe = set() # NOTE de-duping tuple while keeping original order
IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x))
_IMG_EXTENSIONS_SET = set(extensions)
def _valid_extension(x: str):
return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.')
def is_img_extension(ext):
return ext in _IMG_EXTENSIONS_SET
def get_img_extensions(as_set=False):
return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)
def set_img_extensions(extensions):
assert len(extensions)
for x in extensions:
assert _valid_extension(x)
_set_extensions(extensions)
def add_img_extensions(ext):
if not isinstance(ext, (list, tuple, set)):
ext = (ext,)
for x in ext:
assert _valid_extension(x)
extensions = IMG_EXTENSIONS + tuple(ext)
_set_extensions(extensions)
def del_img_extensions(ext):
if not isinstance(ext, (list, tuple, set)):
ext = (ext,)
extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
_set_extensions(extensions)

@ -1,7 +1,6 @@
import os
from .parser_image_folder import ParserImageFolder
from .parser_image_tar import ParserImageTar
from .parser_image_in_tar import ParserImageInTar

@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default.
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
from typing import Dict, List, Optional, Set, Tuple, Union
from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from .img_extensions import get_img_extensions
from .parser import Parser
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
def find_images_and_targets(
folder: str,
types: Optional[Union[List, Tuple, Set]] = None,
class_to_idx: Optional[Dict] = None,
leaf_name_only: bool = True,
sort: bool = True
):
""" Walk folder recursively to discover images and map them to classes by folder names.
Args:
folder: root of folder to recursively search
types: types (file extensions) to search for in path
class_to_idx: specify mapping for class (folder name) to class index if set
leaf_name_only: use only leaf-name of folder walk for class names
sort: re-sort found images by name (for consistent ordering)
Returns:
A list of image and target tuples, class_to_idx mapping
"""
types = get_img_extensions(as_set=True) if not types else set(types)
labels = []
filenames = []
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
@ -51,7 +71,8 @@ class ParserImageFolder(Parser):
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
if len(self.samples) == 0:
raise RuntimeError(
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
f'Found 0 images in subfolders of {root}. '
f'Supported image extensions are {", ".join(get_img_extensions())}')
def __getitem__(self, index):
path, target = self.samples[index]

@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure.
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
import tarfile
import pickle
import logging
import numpy as np
import tarfile
from glob import glob
from typing import List, Dict
from typing import List, Tuple, Dict, Set, Optional, Union
import numpy as np
from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from .img_extensions import get_img_extensions
from .parser import Parser
_logger = logging.getLogger(__name__)
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@ -39,7 +39,7 @@ class TarState:
self.tf = None
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
sample_count = 0
for i, ti in enumerate(tf):
if not ti.isfile():
@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE
return sample_count
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
def extract_tarinfos(
root,
class_name_to_idx: Optional[Dict] = None,
cache_tarinfo: Optional[bool] = None,
extensions: Optional[Union[List, Tuple, Set]] = None,
sort: bool = True
):
extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
root_is_tar = False
if os.path.isfile(root):
assert os.path.splitext(root)[-1].lower() == '.tar'
@ -176,8 +183,8 @@ class ParserImageInTar(Parser):
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
self.root,
class_name_to_idx=class_name_to_idx,
cache_tarinfo=cache_tarinfo,
extensions=IMG_EXTENSIONS)
cache_tarinfo=cache_tarinfo
)
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
if len(tarfiles) == 1 and tarfiles[0][0] is None:
self.root_is_tar = True

@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman
import os
import tarfile
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from timm.utils.misc import natural_key
from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
extensions = get_img_extensions(as_set=True)
files = []
labels = []
for ti in tarfile.getmembers():
@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
dirname, basename = os.path.split(ti.path)
label = os.path.basename(dirname)
ext = os.path.splitext(basename)[1]
if ext.lower() in IMG_EXTENSIONS:
if ext.lower() in extensions:
files.append(ti)
labels.append(label)
if class_to_idx is None:

@ -2,7 +2,6 @@ import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
from PIL import Image
import warnings
import math
import random

@ -12,6 +12,7 @@ from .deit import *
from .densenet import *
from .dla import *
from .dpn import *
from .edgenext import *
from .efficientnet import *
from .ghostnet import *
from .gluon_resnet import *
@ -61,7 +62,7 @@ from .xcit import *
from .factory import create_model, parse_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import convert_splitbn_model, convert_sync_batchnorm
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

@ -17,9 +17,8 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
from .registry import register_model
@ -44,6 +43,7 @@ default_cfgs = dict(
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
convnext_nano_hnf=_cfg(url=''),
convnext_nano_ols=_cfg(url=''),
convnext_tiny_hnf=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
crop_pct=0.95),
@ -88,35 +88,6 @@ default_cfgs = dict(
)
def _is_contiguous(tensor: torch.Tensor) -> bool:
# jit is oh so lovely :/
# if torch.jit.is_tracing():
# return True
if torch.jit.is_scripting():
return tensor.is_contiguous()
else:
return tensor.is_contiguous(memory_format=torch.contiguous_format)
@register_notrace_module
class LayerNorm2d(nn.LayerNorm):
r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__(normalized_shape, eps=eps)
def forward(self, x) -> torch.Tensor:
if _is_contiguous(x):
return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
else:
s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
x = (x - u) * torch.rsqrt(s + self.eps)
x = x * self.weight[:, None, None] + self.bias[:, None, None]
return x
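The custom kernel removed above and the plain permute + `F.layer_norm` path (now the default via `LayerNorm2d`, per the changelog) should agree numerically; a quick standalone check, not part of the diff:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)                   # N, C, H, W
weight, bias, eps = torch.ones(8), torch.zeros(8), 1e-6

# channels-last layer_norm over C, then permute back (LayerNorm2d path)
y1 = F.layer_norm(x.permute(0, 2, 3, 1), (8,), weight, bias, eps).permute(0, 3, 1, 2)

# manual var/mean path from the implementation removed above
s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
y2 = (x - u) * torch.rsqrt(s + eps) * weight[:, None, None] + bias[:, None, None]

print(torch.allclose(y1, y2, atol=1e-5))      # True
```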
class ConvNeXtBlock(nn.Module):
""" ConvNeXt Block
There are two equivalent implementations:
@ -133,16 +104,30 @@ class ConvNeXtBlock(nn.Module):
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""
def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None):
def __init__(
self,
dim,
dim_out=None,
stride=1,
mlp_ratio=4,
conv_mlp=False,
conv_bias=True,
ls_init_value=1e-6,
norm_layer=None,
act_layer=nn.GELU,
drop_path=0.,
):
super().__init__()
dim_out = dim_out or dim
if not norm_layer:
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
mlp_layer = ConvMlp if conv_mlp else Mlp
self.use_conv_mlp = conv_mlp
self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.norm = norm_layer(dim)
self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
self.norm = norm_layer(dim_out)
self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
@ -158,6 +143,7 @@ class ConvNeXtBlock(nn.Module):
x = x.permute(0, 3, 1, 2)
if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
x = self.drop_path(x) + shortcut
return x
@ -165,25 +151,44 @@ class ConvNeXtBlock(nn.Module):
class ConvNeXtStage(nn.Module):
def __init__(
self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False,
norm_layer=None, cl_norm_layer=None, cross_stage=False):
self,
in_chs,
out_chs,
stride=2,
depth=2,
drop_path_rates=None,
ls_init_value=1.0,
conv_mlp=False,
conv_bias=True,
norm_layer=None,
norm_layer_cl=None
):
super().__init__()
self.grad_checkpointing = False
if in_chs != out_chs or stride > 1:
self.downsample = nn.Sequential(
norm_layer(in_chs),
nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride),
nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
)
in_chs = out_chs
else:
self.downsample = nn.Identity()
dp_rates = dp_rates or [0.] * depth
self.blocks = nn.Sequential(*[ConvNeXtBlock(
dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
norm_layer=norm_layer if conv_mlp else cl_norm_layer)
for j in range(depth)]
)
drop_path_rates = drop_path_rates or [0.] * depth
stage_blocks = []
for i in range(depth):
stage_blocks.append(ConvNeXtBlock(
dim=in_chs,
dim_out=out_chs,
drop_path=drop_path_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
norm_layer=norm_layer if conv_mlp else norm_layer_cl
))
in_chs = out_chs
self.blocks = nn.Sequential(*stage_blocks)
def forward(self, x):
x = self.downsample(x)
@ -210,41 +215,56 @@ class ConvNeXt(nn.Module):
"""
def __init__(
self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4,
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch',
head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0.,
self,
in_chans=3,
num_classes=1000,
global_pool='avg',
output_stride=32,
depths=(3, 3, 9, 3),
dims=(96, 192, 384, 768),
ls_init_value=1e-6,
stem_type='patch',
stem_kernel_size=4,
stem_stride=4,
head_init_scale=1.,
head_norm_first=False,
conv_mlp=False,
conv_bias=True,
norm_layer=None,
drop_rate=0.,
drop_path_rate=0.,
):
super().__init__()
assert output_stride == 32
if norm_layer is None:
norm_layer = partial(LayerNorm2d, eps=1e-6)
cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
else:
assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
cl_norm_layer = norm_layer
norm_layer_cl = norm_layer
self.num_classes = num_classes
self.drop_rate = drop_rate
self.feature_info = []
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
assert stem_type in ('patch', 'overlap')
if stem_type == 'patch':
assert stem_kernel_size == stem_stride
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias),
norm_layer(dims[0])
)
curr_stride = patch_size
prev_chs = dims[0]
else:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1),
norm_layer(32),
nn.GELU(),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.Conv2d(
in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride,
padding=stem_kernel_size // 2, bias=conv_bias),
norm_layer(dims[0]),
)
curr_stride = 2
prev_chs = 64
prev_chs = dims[0]
curr_stride = stem_stride
self.stages = nn.Sequential()
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
@ -256,16 +276,23 @@ class ConvNeXt(nn.Module):
curr_stride *= stride
out_chs = dims[i]
stages.append(ConvNeXtStage(
prev_chs, out_chs, stride=stride,
depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
norm_layer=norm_layer, cl_norm_layer=cl_norm_layer)
)
prev_chs,
out_chs,
stride=stride,
depth=depths[i],
drop_path_rates=dp_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl
))
prev_chs = out_chs
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_features = prev_chs
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
@ -327,10 +354,11 @@ class ConvNeXt(nn.Module):
def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02)
nn.init.constant_(module.bias, 0)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
nn.init.constant_(module.bias, 0)
nn.init.zeros_(module.bias)
if name and 'head.' in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
@ -371,14 +399,25 @@ def _create_convnext(variant, pretrained=False, **kwargs):
@register_model
def convnext_nano_hnf(pretrained=False, **kwargs):
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_nano_ols(pretrained=False, **kwargs):
model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True,
conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs)
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_tiny_hnf(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
model_args = dict(
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
return model
@ -386,7 +425,7 @@ def convnext_tiny_hnf(pretrained=False, **kwargs):
@register_model
def convnext_tiny_hnfd(pretrained=False, **kwargs):
model_args = dict(
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, stem_type='dual', **kwargs)
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
return model

File diff suppressed because it is too large

@ -1,12 +1,17 @@
""" DeiT - Data-efficient Image Transformers
DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
Modifications copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
import torch
from torch import nn as nn
@ -53,6 +58,46 @@ default_cfgs = {
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0,
classifier=('head', 'head_dist')),
'deit3_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
'deit3_small_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
'deit3_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_large_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
'deit3_large_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_huge_patch14_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
'deit3_small_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
crop_pct=1.0),
'deit3_small_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_base_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
crop_pct=1.0),
'deit3_base_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_large_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
crop_pct=1.0),
'deit3_large_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_huge_patch14_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
crop_pct=1.0),
}
@ -68,9 +113,10 @@ class VisionTransformerDistilled(VisionTransformer):
super().__init__(*args, **kwargs, weight_init='skip')
assert self.global_pool in ('token',)
self.num_tokens = 2
self.num_prefix_tokens = 2
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False # must set this True to train w/ distillation token
@ -133,7 +179,7 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
model = build_model_with_cfg(
model_cls, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
**kwargs)
return model
@ -220,3 +266,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
model = _create_deit(
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def deit3_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_small_patch16_384(pretrained=False, **kwargs):
""" DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_384(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_384(pretrained=False, **kwargs):
""" DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_huge_patch14_224(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
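The DeiT-3 variants above register as ordinary `timm` entrypoints; a quick smoke test (pretrained=False to skip the weight download):

```python
import torch
from timm.models import create_model

model = create_model('deit3_small_patch16_224', pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```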

@ -0,0 +1,559 @@
""" EdgeNeXt
Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications`
- https://arxiv.org/abs/2206.10589
Original code and weights from https://github.com/mmaaz60/EdgeNeXt
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
import torch
from collections import OrderedDict
from functools import partial
from typing import Tuple
from torch import nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .registry import register_model
__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
edgenext_xx_small=_cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),
edgenext_x_small=_cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),
# edgenext_small=_cfg(
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"),
edgenext_small=_cfg( # USI weights
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
edgenext_small_rw=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
)
@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
def __init__(self, hidden_dim=32, dim=768, temperature=10000):
super().__init__()
self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
self.scale = 2 * math.pi
self.temperature = temperature
self.hidden_dim = hidden_dim
self.dim = dim
def forward(self, shape: Tuple[int, int, int]):
inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool)
y_embed = inv_mask.cumsum(1, dtype=torch.float32)
x_embed = inv_mask.cumsum(2, dtype=torch.float32)
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(),
pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(),
pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
pos = self.token_projection(pos)
return pos
class ConvBlock(nn.Module):
def __init__(
self,
dim,
dim_out=None,
kernel_size=7,
stride=1,
conv_bias=True,
expand_ratio=4,
ls_init_value=1e-6,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU, drop_path=0.,
):
super().__init__()
dim_out = dim_out or dim
self.shortcut_after_dw = stride > 1 or dim != dim_out
self.conv_dw = create_conv2d(
dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias)
self.norm = norm_layer(dim_out)
self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv_dw(x)
if self.shortcut_after_dw:
shortcut = x
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.mlp(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = shortcut + self.drop_path(x)
return x
class CrossCovarianceAttn(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.
):
super().__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
q, k, v = qkv.unbind(0)
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@torch.jit.ignore
def no_weight_decay(self):
return {'temperature'}
class SplitTransposeBlock(nn.Module):
def __init__(
self,
dim,
num_scales=1,
num_heads=8,
expand_ratio=4,
use_pos_emb=True,
conv_bias=True,
qkv_bias=True,
ls_init_value=1e-6,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
drop_path=0.,
attn_drop=0.,
proj_drop=0.
):
super().__init__()
width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales)))
self.width = width
self.num_scales = max(1, num_scales - 1)
convs = []
for i in range(self.num_scales):
convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias))
self.convs = nn.ModuleList(convs)
self.pos_embd = None
if use_pos_emb:
self.pos_embd = PositionalEncodingFourier(dim=dim)
self.norm_xca = norm_layer(dim)
self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
self.xca = CrossCovarianceAttn(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
self.norm = norm_layer(dim, eps=1e-6)
self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
# scales code re-written for torchscript as per my res2net fixes -rw
spx = torch.split(x, self.width, 1)
spo = []
sp = spx[0]
for i, conv in enumerate(self.convs):
if i > 0:
sp = sp + spx[i]
sp = conv(sp)
spo.append(sp)
spo.append(spx[-1])
x = torch.cat(spo, 1)
# XCA
B, C, H, W = x.shape
x = x.reshape(B, C, H * W).permute(0, 2, 1)
if self.pos_embd is not None:
pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
x = x + pos_encoding
x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x)))
x = x.reshape(B, H, W, C)
# Inverted Bottleneck
x = self.norm(x)
x = self.mlp(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = shortcut + self.drop_path(x)
return x
class EdgeNeXtStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
stride=2,
depth=2,
num_global_blocks=1,
num_heads=4,
scales=2,
kernel_size=7,
expand_ratio=4,
use_pos_emb=False,
downsample_block=False,
conv_bias=True,
ls_init_value=1.0,
drop_path_rates=None,
norm_layer=LayerNorm2d,
norm_layer_cl=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU
):
super().__init__()
self.grad_checkpointing = False
if downsample_block or stride == 1:
self.downsample = nn.Identity()
else:
self.downsample = nn.Sequential(
norm_layer(in_chs),
nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias)
)
in_chs = out_chs
stage_blocks = []
for i in range(depth):
if i < depth - num_global_blocks:
stage_blocks.append(
ConvBlock(
dim=in_chs,
dim_out=out_chs,
stride=stride if downsample_block and i == 0 else 1,
conv_bias=conv_bias,
kernel_size=kernel_size,
expand_ratio=expand_ratio,
ls_init_value=ls_init_value,
drop_path=drop_path_rates[i],
norm_layer=norm_layer_cl,
act_layer=act_layer,
)
)
else:
stage_blocks.append(
SplitTransposeBlock(
dim=in_chs,
num_scales=scales,
num_heads=num_heads,
expand_ratio=expand_ratio,
use_pos_emb=use_pos_emb,
conv_bias=conv_bias,
ls_init_value=ls_init_value,
drop_path=drop_path_rates[i],
norm_layer=norm_layer_cl,
act_layer=act_layer,
)
)
in_chs = out_chs
self.blocks = nn.Sequential(*stage_blocks)
def forward(self, x):
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class EdgeNeXt(nn.Module):
def __init__(
self,
in_chans=3,
num_classes=1000,
global_pool='avg',
dims=(24, 48, 88, 168),
depths=(3, 3, 9, 3),
global_block_counts=(0, 1, 1, 1),
kernel_sizes=(3, 5, 7, 9),
heads=(8, 8, 8, 8),
d2_scales=(2, 2, 3, 4),
use_pos_emb=(False, True, False, False),
ls_init_value=1e-6,
head_init_scale=1.,
expand_ratio=4,
downsample_block=False,
conv_bias=True,
stem_type='patch',
head_norm_first=False,
act_layer=nn.GELU,
drop_path_rate=0.,
drop_rate=0.,
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.drop_rate = drop_rate
norm_layer = partial(LayerNorm2d, eps=1e-6)
norm_layer_cl = partial(nn.LayerNorm, eps=1e-6)
self.feature_info = []
assert stem_type in ('patch', 'overlap')
if stem_type == 'patch':
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias),
norm_layer(dims[0]),
)
else:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias),
norm_layer(dims[0]),
)
curr_stride = 4
stages = []
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
in_chs = dims[0]
for i in range(4):
stride = 2 if curr_stride == 2 or i > 0 else 1
# FIXME support dilation / output_stride
curr_stride *= stride
stages.append(EdgeNeXtStage(
in_chs=in_chs,
out_chs=dims[i],
stride=stride,
depth=depths[i],
num_global_blocks=global_block_counts[i],
num_heads=heads[i],
drop_path_rates=dp_rates[i],
scales=d2_scales[i],
expand_ratio=expand_ratio,
kernel_size=kernel_sizes[i],
use_pos_emb=use_pos_emb[i],
ls_init_value=ls_init_value,
downsample_block=downsample_block,
conv_bias=conv_bias,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
act_layer=act_layer,
))
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
in_chs = dims[i]
self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_features = dims[-1]
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.downsample', (0,)), # blocks
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm_pre', (99999,))
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
x = self.head.global_pool(x)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d):
trunc_normal_tf_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
trunc_normal_tf_(module.weight, std=.02)
nn.init.zeros_(module.bias)
if name and 'head.' in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
def checkpoint_filter_fn(state_dict, model):
""" Remap FB checkpoints -> timm """
if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
return state_dict # non-FB checkpoint
# models were released as train checkpoints... :/
if 'model_ema' in state_dict:
state_dict = state_dict['model_ema']
elif 'model' in state_dict:
state_dict = state_dict['model']
elif 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
out_dict = {}
import re
for k, v in state_dict.items():
k = k.replace('downsample_layers.0.', 'stem.')
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
k = k.replace('dwconv', 'conv_dw')
k = k.replace('pwconv', 'mlp.fc')
k = k.replace('head.', 'head.fc.')
if k.startswith('norm.'):
k = k.replace('norm', 'head.norm')
if v.ndim == 2 and 'head' not in k:
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)
out_dict[k] = v
return out_dict
def _create_edgenext(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
EdgeNeXt, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
**kwargs)
return model
@register_model
def edgenext_xx_small(pretrained=False, **kwargs):
# 1.33M & 260.58M @ 256 resolution
# 71.23% Top-1 accuracy
# No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=51.66 versus 47.67 for MobileViT_XXS
# For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS
model_kwargs = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4), **kwargs)
return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **model_kwargs)
@register_model
def edgenext_x_small(pretrained=False, **kwargs):
# 2.34M & 538.0M @ 256 resolution
# 75.00% Top-1 accuracy
# No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=31.61 versus 28.49 for MobileViT_XS
# For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS
model_kwargs = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4), **kwargs)
return _create_edgenext('edgenext_x_small', pretrained=pretrained, **model_kwargs)
@register_model
def edgenext_small(pretrained=False, **kwargs):
# 5.59M & 1260.59M @ 256 resolution
# 79.43% Top-1 accuracy
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
model_kwargs = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304), **kwargs)
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
@register_model
def edgenext_small_rw(pretrained=False, **kwargs):
# 5.59M & 1260.59M @ 256 resolution
# 79.43% Top-1 accuracy
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
model_kwargs = dict(
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **model_kwargs)
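# Hedged usage sketch for the EdgeNeXt variants registered above: any of the
# @register_model names resolves through timm's factory. Assumes a timm install
# that includes this edgenext.py; shapes are only illustrative.
import torch
import timm

model = timm.create_model('edgenext_small_rw', pretrained=False).eval()
x = torch.randn(1, 3, 256, 256)  # result comments above quote accuracy at 256px
with torch.no_grad():
    feats = model.forward_features(x)  # pre-pool feature map, channels == dims[-1]
    logits = model(x)                  # (1, 1000) classification logits
print(feats.shape, logits.shape)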

@ -455,18 +455,26 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
filter_kwargs(kwargs, names=kwargs_filter)
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None):
if pretrained_cfg and isinstance(pretrained_cfg, dict):
# highest priority, pretrained_cfg available and passed explicitly
# highest priority, pretrained_cfg available and passed as arg
return deepcopy(pretrained_cfg)
if kwargs and 'pretrained_cfg' in kwargs:
# next highest, pretrained_cfg in a kwargs dict, pop and return
pretrained_cfg = kwargs.pop('pretrained_cfg', {})
if pretrained_cfg:
return deepcopy(pretrained_cfg)
# lookup pretrained cfg in model registry by variant
# fallback to looking up pretrained cfg in model registry by variant identifier
pretrained_cfg = get_pretrained_cfg(variant)
assert pretrained_cfg
if not pretrained_cfg:
_logger.warning(
f"No pretrained configuration specified for {variant} model. Using a default."
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
pretrained_cfg = dict(
url='',
num_classes=1000,
input_size=(3, 224, 224),
pool_size=None,
crop_pct=.9,
interpolation='bicubic',
first_conv='',
classifier='',
)
return pretrained_cfg
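# A minimal sketch of the resolution order above (new signature, no `kwargs` arg):
# an explicitly passed pretrained_cfg dict wins and is deep-copied, otherwise the
# variant name is looked up in the registry, and a missing entry now warns and
# returns a generic default instead of tripping an assert.
# 'my_variant' is a hypothetical name used only for illustration.
from timm.models.helpers import resolve_pretrained_cfg

explicit = dict(url='', num_classes=10, input_size=(3, 32, 32), crop_pct=1.0)
cfg = resolve_pretrained_cfg('my_variant', pretrained_cfg=explicit)
assert cfg == explicit and cfg is not explicit  # copy of the explicit dict, highest priority
cfg = resolve_pretrained_cfg('resnet18')        # found in the pretrained_cfg registry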

@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3):
def _create_inception_v3(variant, pretrained=False, **kwargs):
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
aux_logits = kwargs.pop('aux_logits', False)
if aux_logits:
assert not kwargs.pop('features_only', False)

@ -25,8 +25,8 @@ from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .norm import GroupNorm, GroupNorm1, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
@ -39,4 +39,4 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_

@ -2,6 +2,7 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import functools
from torch import nn as nn
from .create_conv2d import create_conv2d
@ -40,12 +41,26 @@ class ConvNormAct(nn.Module):
ConvBnAct = ConvNormAct
def create_aa(aa_layer, channels, stride=2, enable=True):
if not aa_layer or not enable:
return nn.Identity()
if isinstance(aa_layer, functools.partial):
if issubclass(aa_layer.func, nn.AvgPool2d):
return aa_layer()
else:
return aa_layer(channels)
elif issubclass(aa_layer, nn.AvgPool2d):
return aa_layer(stride)
else:
return aa_layer(channels=channels, stride=stride)
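# A small sketch of the `create_aa` dispatch above, assuming it is importable from
# timm.models.layers.conv_bn_act as in this diff: pooling-style anti-aliasing layers
# are built from the stride alone, anything else also receives the channel count
# (BlurPool2d's exact signature is an assumption here, not guaranteed).
import functools
from torch import nn
from timm.models.layers import BlurPool2d
from timm.models.layers.conv_bn_act import create_aa

aa1 = create_aa(nn.AvgPool2d, channels=64, stride=2)                          # -> nn.AvgPool2d(2)
aa2 = create_aa(BlurPool2d, channels=64, stride=2)                            # -> BlurPool2d(channels=64, stride=2)
aa3 = create_aa(functools.partial(nn.AvgPool2d, kernel_size=3), channels=64)  # partial of AvgPool2d -> called with no args
aa4 = create_aa(None, channels=64)                                            # -> nn.Identity()
print(aa1, aa2, aa3, aa4)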
class ConvNormActAa(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
super(ConvNormActAa, self).__init__()
use_aa = aa_layer is not None
use_aa = aa_layer is not None and stride == 2
self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
@ -56,7 +71,7 @@ class ConvNormActAa(nn.Module):
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity()
self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
@property
def in_channels(self):

@ -22,7 +22,7 @@ def get_attn(attn_type):
if isinstance(attn_type, torch.nn.Module):
return attn_type
module_cls = None
if attn_type is not None:
if attn_type:
if isinstance(attn_type, str):
attn_type = attn_type.lower()
# Lightweight attention modules (channel and/or coarse spatial).

@ -164,3 +164,6 @@ class DropPath(nn.Module):
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f'drop_prob={round(self.drop_prob,3):0.3f}'

@ -256,8 +256,9 @@ class EvoNorm2dS0a(EvoNorm2dS0):
class EvoNorm2dS1(nn.Module):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
apply_act=True, act_layer=None, eps=1e-5, **_):
super().__init__()
act_layer = act_layer or nn.SiLU
self.apply_act = apply_act # apply activation (non-linearity)
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer)
@ -290,7 +291,7 @@ class EvoNorm2dS1(nn.Module):
class EvoNorm2dS1a(EvoNorm2dS1):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
apply_act=True, act_layer=None, eps=1e-3, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
@ -305,8 +306,9 @@ class EvoNorm2dS1a(EvoNorm2dS1):
class EvoNorm2dS2(nn.Module):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
apply_act=True, act_layer=None, eps=1e-5, **_):
super().__init__()
act_layer = act_layer or nn.SiLU
self.apply_act = apply_act # apply activation (non-linearity)
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer)
@ -338,7 +340,7 @@ class EvoNorm2dS2(nn.Module):
class EvoNorm2dS2a(EvoNorm2dS2):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
apply_act=True, act_layer=None, eps=1e-3, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)

@ -14,11 +14,59 @@ class GroupNorm(nn.GroupNorm):
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, *]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class LayerNorm2d(nn.LayerNorm):
""" LayerNorm for channels of '2D' spatial BCHW tensors """
def __init__(self, num_channels):
super().__init__(num_channels)
""" LayerNorm for channels of '2D' spatial NCHW tensors """
def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
def _is_contiguous(tensor: torch.Tensor) -> bool:
# jit is oh so lovely :/
# if torch.jit.is_tracing():
# return True
if torch.jit.is_scripting():
return tensor.is_contiguous()
else:
return tensor.is_contiguous(memory_format=torch.contiguous_format)
@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
x = (x - u) * torch.rsqrt(s + eps)
x = x * weight[:, None, None] + bias[:, None, None]
return x
class LayerNormExp2d(nn.LayerNorm):
""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
Experimental implementation w/ manual norm for non-contiguous tensors.
This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
layout. However, benefits are not always clear and can perform worse on other GPUs.
"""
def __init__(self, num_channels, eps=1e-6):
super().__init__(num_channels, eps=eps)
def forward(self, x) -> torch.Tensor:
if _is_contiguous(x):
x = F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
else:
x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
return x
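# A quick numeric check of the channels-first LayerNorm above: LayerNorm2d normalizes
# over C of an NCHW tensor by round-tripping through NHWC and should match a manual
# per-position normalization to floating point tolerance.
import torch
from timm.models.layers import LayerNorm2d

ln = LayerNorm2d(32)
x = torch.randn(2, 32, 7, 7)
y = ln(x)

u = x.mean(dim=1, keepdim=True)
s = x.var(dim=1, unbiased=False, keepdim=True)
manual = (x - u) * torch.rsqrt(s + ln.eps)
manual = manual * ln.weight[:, None, None] + ln.bias[:, None, None]
print(torch.allclose(y, manual, atol=1e-5))  # expected: True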

@ -1,6 +1,6 @@
""" Normalization + Activation Layers
"""
from typing import Union, List
from typing import Union, List, Optional, Any
import torch
from torch import nn as nn
@ -18,10 +18,29 @@ class BatchNormAct2d(nn.BatchNorm2d):
instead of composing it as a .bn member.
"""
def __init__(
self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
device=None,
dtype=None
):
try:
factory_kwargs = {'device': device, 'dtype': dtype}
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
**factory_kwargs
)
except TypeError:
# NOTE for backwards compat with old PyTorch w/o factory device/dtype support
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
@ -81,6 +100,62 @@ class BatchNormAct2d(nn.BatchNorm2d):
return x
class SyncBatchNormAct(nn.SyncBatchNorm):
# Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
# This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
# but ONLY when used in conjunction with the timm conversion function below.
# Do not create this module directly or use the PyTorch conversion function.
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = super().forward(x) # SyncBN doesn't work with torchscript anyways, so this is fine
if hasattr(self, "drop"):
x = self.drop(x)
if hasattr(self, "act"):
x = self.act(x)
return x
def convert_sync_batchnorm(module, process_group=None):
# convert both BatchNorm and BatchNormAct layers to Synchronized variants
module_output = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
if isinstance(module, BatchNormAct2d):
# convert timm norm + act layer
module_output = SyncBatchNormAct(
module.num_features,
module.eps,
module.momentum,
module.affine,
module.track_running_stats,
process_group=process_group,
)
# set act and drop attr from the original module
module_output.act = module.act
module_output.drop = module.drop
else:
# convert standard BatchNorm layers
module_output = torch.nn.SyncBatchNorm(
module.num_features,
module.eps,
module.momentum,
module.affine,
module.track_running_stats,
process_group,
)
if module.affine:
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
if hasattr(module, "qconfig"):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, convert_sync_batchnorm(child, process_group))
del module
return module_output
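# Hedged sketch of the timm-aware SyncBatchNorm conversion above: unlike
# torch.nn.SyncBatchNorm.convert_sync_batchnorm, this version keeps the fused
# drop/act members of BatchNormAct2d by swapping in SyncBatchNormAct. Actually
# running the converted model still requires an initialized process group.
import torch
from torch import nn
from timm.models.layers import BatchNormAct2d, convert_sync_batchnorm

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    BatchNormAct2d(16),
)
sync_model = convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)   # SyncBatchNormAct
print(hasattr(sync_model[1], 'act'))  # True, activation preserved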
def group_norm_tpu(x, w, b, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
# This is a workaround for some odd behaviour running on PyTorch XLA w/ TPUs.
x_shape = x.shape

@ -36,7 +36,7 @@ class TestTimePoolHead(nn.Module):
return x.view(x.size(0), -1)
def apply_test_time_pool(model, config, use_test_size=True):
def apply_test_time_pool(model, config, use_test_size=False):
test_time_pool = False
if not hasattr(model, 'default_cfg') or not model.default_cfg:
return model, False

@ -49,6 +49,11 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
applied while sampling the normal with mean/std applied, therefore a, b args
should be adjusted to match the range of mean, std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
@ -62,6 +67,35 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
and the result is subsequently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> trunc_normal_tf_(w)
"""
_no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
with torch.no_grad():
tensor.mul_(std).add_(mean)
return tensor
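# Hedged illustration of the distinction documented above: `trunc_normal_tf_` clips the
# unit-variance draw to [a, b] and only then scales/shifts, so with std=0.02 the samples
# land in roughly [-0.04, 0.04]; plain `trunc_normal_` applies [a, b] to the already
# scaled values, where +/-2 barely binds.
import torch
from timm.models.layers import trunc_normal_, trunc_normal_tf_

w1 = torch.empty(10000)
w2 = torch.empty(10000)
trunc_normal_(w1, std=0.02, a=-2., b=2.)
trunc_normal_tf_(w2, std=0.02, a=-2., b=2.)
print(w1.abs().max().item(), w2.abs().max().item())  # roughly 0.08 vs <= 0.04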
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == 'fan_in':
@ -75,7 +109,7 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":

@ -1,7 +1,8 @@
""" MobileViT
Paper:
`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178
V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178
V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680
MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below)
License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source)
@ -13,7 +14,7 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#
import math
from typing import Union, Callable, Dict, Tuple, Optional
from typing import Union, Callable, Dict, Tuple, Optional, Sequence
import torch
from torch import nn
@ -21,7 +22,7 @@ import torch.nn.functional as F
from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
from .fx_features import register_notrace_module
from .layers import to_2tuple, make_divisible
from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath
from .vision_transformer import Block as TransformerBlock
from .helpers import build_model_with_cfg
from .registry import register_model
@ -48,6 +49,48 @@ default_cfgs = {
'mobilevit_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'),
'semobilevit_s': _cfg(),
'mobilevitv2_050': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth',
crop_pct=0.888),
'mobilevitv2_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth',
crop_pct=0.888),
'mobilevitv2_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth',
crop_pct=0.888),
'mobilevitv2_125': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth',
crop_pct=0.888),
'mobilevitv2_150': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth',
crop_pct=0.888),
'mobilevitv2_175': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth',
crop_pct=0.888),
'mobilevitv2_200': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth',
crop_pct=0.888),
'mobilevitv2_150_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth',
crop_pct=0.888),
'mobilevitv2_175_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth',
crop_pct=0.888),
'mobilevitv2_200_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth',
crop_pct=0.888),
'mobilevitv2_150_384_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'mobilevitv2_175_384_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'mobilevitv2_200_384_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
}
@ -72,6 +115,40 @@ def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4,
)
def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5):
# inverted residual + mobilevit blocks as per MobileViT network
return (
_inverted_residual_block(d=d, c=c, s=s, br=br),
ByoBlockCfg(
type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1,
block_kwargs=dict(
transformer_depth=transformer_depth,
patch_size=patch_size)
)
)
def _mobilevitv2_cfg(multiplier=1.0):
chs = (64, 128, 256, 384, 512)
if multiplier != 1.0:
chs = tuple([int(c * multiplier) for c in chs])
cfg = ByoModelCfg(
blocks=(
_inverted_residual_block(d=1, c=chs[0], s=1, br=2.0),
_inverted_residual_block(d=2, c=chs[1], s=2, br=2.0),
_mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2),
_mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4),
_mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3),
),
stem_chs=int(32 * multiplier),
stem_type='3x3',
stem_pool='',
downsample='',
act_layer='silu',
)
return cfg
model_cfgs = dict(
mobilevit_xxs=ByoModelCfg(
blocks=(
@ -137,11 +214,19 @@ model_cfgs = dict(
attn_kwargs=dict(rd_ratio=1/8),
num_features=640,
),
mobilevitv2_050=_mobilevitv2_cfg(.50),
mobilevitv2_075=_mobilevitv2_cfg(.75),
mobilevitv2_100=_mobilevitv2_cfg(1.0),
mobilevitv2_125=_mobilevitv2_cfg(1.25),
mobilevitv2_150=_mobilevitv2_cfg(1.5),
mobilevitv2_175=_mobilevitv2_cfg(1.75),
mobilevitv2_200=_mobilevitv2_cfg(2.0),
)
@register_notrace_module
class MobileViTBlock(nn.Module):
class MobileVitBlock(nn.Module):
""" MobileViT block
Paper: https://arxiv.org/abs/2110.02178?context=cs.LG
"""
@ -165,9 +250,9 @@ class MobileViTBlock(nn.Module):
drop_path_rate: float = 0.,
layers: LayerFn = None,
transformer_norm_layer: Callable = nn.LayerNorm,
downsample: str = ''
**kwargs, # eat unused args
):
super(MobileViTBlock, self).__init__()
super(MobileVitBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
@ -241,7 +326,270 @@ class MobileViTBlock(nn.Module):
return x
register_block('mobilevit', MobileViTBlock)
class LinearSelfAttention(nn.Module):
"""
This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680`
This layer can be used for self- as well as cross-attention.
Args:
embed_dim (int): :math:`C` from an expected input of size :math:`(B, C, P, N)`
attn_drop (float): Dropout value for context scores. Default: 0.0
bias (bool): Use bias in learnable layers. Default: True
Shape:
- Input: :math:`(B, C, P, N)` where :math:`B` is the batch size, :math:`C` is the input channels,
:math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches
- Output: same as the input
.. note::
For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels
in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor,
we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be
expensive on resource-constrained devices) that may be required to convert the unfolded tensor from
channel-first to channel-last format in case of a linear layer.
"""
def __init__(
self,
embed_dim: int,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
self.embed_dim = embed_dim
self.qkv_proj = nn.Conv2d(
in_channels=embed_dim,
out_channels=1 + (2 * embed_dim),
bias=bias,
kernel_size=1,
)
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = nn.Conv2d(
in_channels=embed_dim,
out_channels=embed_dim,
bias=bias,
kernel_size=1,
)
self.out_drop = nn.Dropout(proj_drop)
def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor:
# [B, C, P, N] --> [B, h + 2d, P, N]
qkv = self.qkv_proj(x)
# Project x into query, key and value
# Query --> [B, 1, P, N]
# value, key --> [B, d, P, N]
query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
# apply softmax along N dimension
context_scores = F.softmax(query, dim=-1)
context_scores = self.attn_drop(context_scores)
# Compute context vector
# [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
# combine context vector with values
# [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
out = F.relu(value) * context_vector.expand_as(value)
out = self.out_proj(out)
out = self.out_drop(out)
return out
@torch.jit.ignore()
def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
# x --> [B, C, P, N]
# x_prev = [B, C, P, M]
batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape
q_patch_area, q_num_patches = x_prev.shape[-2:]
assert (
kv_patch_area == q_patch_area
), "The number of pixels in a patch for query and key_value should be the same"
# compute query, key, and value
# [B, C, P, M] --> [B, 1 + d, P, M]
qk = F.conv2d(
x_prev,
weight=self.qkv_proj.weight[:self.embed_dim + 1],
bias=self.qkv_proj.bias[:self.embed_dim + 1],
)
# [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
query, key = qk.split([1, self.embed_dim], dim=1)
# [B, C, P, N] --> [B, d, P, N]
value = F.conv2d(
x,
weight=self.qkv_proj.weight[self.embed_dim + 1:, ...],
bias=self.qkv_proj.bias[self.embed_dim + 1:] if self.qkv_proj.bias is not None else None,
)
# apply softmax along M dimension
context_scores = F.softmax(query, dim=-1)
context_scores = self.attn_drop(context_scores)
# compute context vector
# [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1]
context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
# combine context vector with values
# [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
out = F.relu(value) * context_vector.expand_as(value)
out = self.out_proj(out)
out = self.out_drop(out)
return out
def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
if x_prev is None:
return self._forward_self_attn(x)
else:
return self._forward_cross_attn(x, x_prev=x_prev)
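# A tensor-only sketch of the linear self-attention math above: a single scalar query
# score per token is softmaxed over the patch axis N, pools the keys into one context
# vector per pixel position, and that context gates the ReLU'd values, giving O(N) cost
# instead of the O(N^2) of dot-product attention.
import torch
import torch.nn.functional as F

B, d, P, N = 2, 32, 4, 64
query = torch.randn(B, 1, P, N)
key = torch.randn(B, d, P, N)
value = torch.randn(B, d, P, N)

scores = F.softmax(query, dim=-1)                   # [B, 1, P, N]
context = (key * scores).sum(dim=-1, keepdim=True)  # [B, d, P, 1]
out = F.relu(value) * context.expand_as(value)      # [B, d, P, N]
print(out.shape)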
class LinearTransformerBlock(nn.Module):
"""
This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 paper <https://arxiv.org/abs/2206.02680>`_
Args:
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)`
mlp_ratio (float): Inner dimension ratio of the FFN relative to embed_dim
drop (float): Dropout rate. Default: 0.0
attn_drop (float): Dropout rate for attention in multi-head attention. Default: 0.0
drop_path (float): Stochastic depth rate Default: 0.0
norm_layer (Callable): Normalization layer. Default: layer_norm_2d
Shape:
- Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim,
:math:`P` is number of pixels in a patch, and :math:`N` is number of patches,
- Output: same shape as the input
"""
def __init__(
self,
embed_dim: int,
mlp_ratio: float = 2.0,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
act_layer=None,
norm_layer=None,
) -> None:
super().__init__()
act_layer = act_layer or nn.SiLU
norm_layer = norm_layer or GroupNorm1
self.norm1 = norm_layer(embed_dim)
self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop)
self.drop_path1 = DropPath(drop_path)
self.norm2 = norm_layer(embed_dim)
self.mlp = ConvMlp(
in_features=embed_dim,
hidden_features=int(embed_dim * mlp_ratio),
act_layer=act_layer,
drop=drop)
self.drop_path2 = DropPath(drop_path)
def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
if x_prev is None:
# self-attention
x = x + self.drop_path1(self.attn(self.norm1(x)))
else:
# cross-attention
res = x
x = self.norm1(x) # norm
x = self.attn(x, x_prev) # attn
x = self.drop_path1(x) + res # residual
# Feed forward network
x = x + self.drop_path2(self.mlp(self.norm2(x)))
return x
@register_notrace_module
class MobileVitV2Block(nn.Module):
"""
This class defines the `MobileViTv2 block <https://arxiv.org/abs/2206.02680>`_
"""
def __init__(
self,
in_chs: int,
out_chs: Optional[int] = None,
kernel_size: int = 3,
bottle_ratio: float = 1.0,
group_size: Optional[int] = 1,
dilation: Tuple[int, int] = (1, 1),
mlp_ratio: float = 2.0,
transformer_dim: Optional[int] = None,
transformer_depth: int = 2,
patch_size: int = 8,
attn_drop: float = 0.,
drop: float = 0.,
drop_path_rate: float = 0.,
layers: LayerFn = None,
transformer_norm_layer: Callable = GroupNorm1,
**kwargs, # eat unused args
):
super(MobileVitV2Block, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
out_chs = out_chs or in_chs
transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)
self.conv_kxk = layers.conv_norm_act(
in_chs, in_chs, kernel_size=kernel_size,
stride=1, groups=groups, dilation=dilation[0])
self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)
self.transformer = nn.Sequential(*[
LinearTransformerBlock(
transformer_dim,
mlp_ratio=mlp_ratio,
attn_drop=attn_drop,
drop=drop,
drop_path=drop_path_rate,
act_layer=layers.act,
norm_layer=transformer_norm_layer
)
for _ in range(transformer_depth)
])
self.norm = transformer_norm_layer(transformer_dim)
self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False)
self.patch_size = to_2tuple(patch_size)
self.patch_area = self.patch_size[0] * self.patch_size[1]
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
patch_h, patch_w = self.patch_size
new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w
num_patches = num_patch_h * num_patch_w # N
if new_h != H or new_w != W:
x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True)
# Local representation
x = self.conv_kxk(x)
x = self.conv_1x1(x)
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
C = x.shape[1]
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
x = x.reshape(B, C, -1, num_patches)
# Global representations
x = self.transformer(x)
x = self.norm(x)
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
x = self.conv_proj(x)
return x
register_block('mobilevit', MobileVitBlock)
register_block('mobilevit2', MobileVitV2Block)
def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
@ -252,6 +600,14 @@ def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
**kwargs)
def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
ByobNet, variant, pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model
def mobilevit_xxs(pretrained=False, **kwargs):
return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)
@ -269,4 +625,75 @@ def mobilevit_s(pretrained=False, **kwargs):
@register_model
def semobilevit_s(pretrained=False, **kwargs):
return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs)
return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_050(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_075(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_100(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_125(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_150(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_175(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_200(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)
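# Hedged sketch of the unfold/fold bookkeeping in MobileVitV2Block.forward above:
# a [B, C, H, W] map is regrouped into [B, C, P, N] (P = pixels per patch, N = number
# of patches) for the linear-attention transformer, then folded back; the round trip
# is lossless when H and W are multiples of the patch size.
import torch

B, C, H, W = 2, 64, 32, 32
ph, pw = 2, 2
nh, nw = H // ph, W // pw
x = torch.randn(B, C, H, W)

patches = x.reshape(B, C, nh, ph, nw, pw).permute(0, 1, 3, 5, 2, 4).reshape(B, C, ph * pw, nh * nw)
folded = patches.reshape(B, C, ph, pw, nh, nw).permute(0, 1, 4, 2, 5, 3).reshape(B, C, H, W)
print(patches.shape, torch.equal(x, folded))  # torch.Size([2, 64, 4, 256]) True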

@ -26,7 +26,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp
from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1
from .registry import register_model
@ -80,15 +80,6 @@ class PatchEmbed(nn.Module):
return x
class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class Pooling(nn.Module):
def __init__(self, pool_size=3):
super().__init__()

@ -35,6 +35,16 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# ResNet and Wide ResNet
'resnet10t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet10t_176_c3-f3215ab1.pth',
input_size=(3, 176, 176), pool_size=(6, 6),
test_crop_pct=0.95, test_input_size=(3, 224, 224),
first_conv='conv1.0'),
'resnet14t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet14t_176_c3-c4ed2c37.pth',
input_size=(3, 176, 176), pool_size=(6, 6),
test_crop_pct=0.95, test_input_size=(3, 224, 224),
first_conv='conv1.0'),
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
'resnet18d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth',
@ -262,6 +272,9 @@ default_cfgs = {
'resnetblur101d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'resnetaa50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth',
test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic'),
'resnetaa50d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
@ -723,6 +736,24 @@ def _create_resnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
@register_model
def resnet10t(pretrained=False, **kwargs):
"""Constructs a ResNet-10-T model.
"""
model_args = dict(
block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
return _create_resnet('resnet10t', pretrained, **model_args)
@register_model
def resnet14t(pretrained=False, **kwargs):
"""Constructs a ResNet-14-T model.
"""
model_args = dict(
block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
return _create_resnet('resnet14t', pretrained, **model_args)
@register_model
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
@ -1436,6 +1467,14 @@ def resnetblur101d(pretrained=False, **kwargs):
return _create_resnet('resnetblur101d', pretrained, **model_args)
@register_model
def resnetaa50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with avgpool anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, **kwargs)
return _create_resnet('resnetaa50', pretrained, **model_args)
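# Hedged usage sketch for the ResNet additions above: `resnet10t` / `resnet14t` are the
# new 1-block-repeat defs and `resnetaa50` is a ResNet-50 whose strided blocks use an
# avg-pool anti-aliasing layer; all resolve through the usual factory.
import torch
import timm

tiny = timm.create_model('resnet10t', pretrained=False).eval()
aa50 = timm.create_model('resnetaa50', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print(tiny(x).shape, aa50(x).shape)  # both (1, 1000)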
@register_model
def resnetaa50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with avgpool anti-aliasing

@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
"""
Args:
img_size (int, tuple): input image size
@ -360,15 +360,17 @@ class VisionTransformer(nn.Module):
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1 if class_token else 0
self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
@ -428,11 +430,24 @@ class VisionTransformer(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _pos_embed(self, x):
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + self.pos_embed
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.pos_embed
return self.pos_drop(x)
def forward_features(self, x):
x = self.patch_embed(x)
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self._pos_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
@ -442,7 +457,7 @@ class VisionTransformer(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
return x if pre_logits else self.head(x)
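# A shape-only sketch of the `_pos_embed` branch added above: with `no_embed_class`
# (deit-3 / big-vision style) the embedding table covers only the patch tokens and the
# class token is concatenated after the add; in the original layout the table already
# holds a row for the class token, so the concat happens first.
import torch

num_patches, embed_dim = 196, 768
patch_tokens = torch.randn(2, num_patches, embed_dim)
cls_token = torch.zeros(2, 1, embed_dim)

# no_embed_class=True: add, then concat
pos_embed_a = torch.randn(1, num_patches, embed_dim)
x_a = torch.cat([cls_token, patch_tokens + pos_embed_a], dim=1)   # (2, 197, 768)

# no_embed_class=False: concat, then add (original timm / JAX / deit layout)
pos_embed_b = torch.randn(1, num_patches + 1, embed_dim)
x_b = torch.cat([cls_token, patch_tokens], dim=1) + pos_embed_b   # (2, 197, 768)
print(x_a.shape, x_b.shape)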
@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
pos_embed_w,
model.pos_embed,
getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
ntok_new = posemb_new.shape[1]
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
ntok_new -= num_tokens
if num_prefix_tokens:
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
ntok_new -= num_prefix_tokens
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
return posemb
def checkpoint_filter_fn(state_dict, model):
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
import re
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification
O, I, H, W = model.patch_embed.proj.weight.shape
v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
v,
model.pos_embed,
getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
elif adapt_layer_scale and 'gamma_' in k:
# remap layer-scale gamma into sub-module (deit3 models)
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
elif 'pre_logits' in k:
# NOTE representation layer removed as not used in latest 21k/1k pretrained weights
continue
@ -633,7 +661,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
model = build_model_with_cfg(
VisionTransformer, variant, pretrained,
pretrained_cfg=pretrained_cfg,

@ -8,6 +8,7 @@ import math
import logging
from functools import partial
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
@ -47,9 +48,16 @@ default_cfgs = {
'vit_relpos_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
'vit_srelpos_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'),
'vit_srelpos_medium_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'),
'vit_relpos_medium_patch16_cls_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'),
'vit_relpos_base_patch16_cls_224': _cfg(
url=''),
'vit_relpos_base_patch16_gapcls_224': _cfg(
'vit_relpos_base_patch16_clsgap_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
@ -59,35 +67,43 @@ default_cfgs = {
}
def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor:
# cut and paste w/ modifications from swin / beit codebase
# cls to token & token 2 cls & cls to cls
def gen_relative_position_index(
q_size: Tuple[int, int],
k_size: Tuple[int, int] = None,
class_token: bool = False) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
window_area = win_size[0] * win_size[1]
coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += win_size[1] - 1
relative_coords[:, :, 0] *= 2 * win_size[1] - 1
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
if k_size is None:
k_coords = q_coords
k_size = q_size
else:
# different q vs k sizes is a WIP
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
if class_token:
num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
# NOTE not intended or tested with MLP log-coords
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
else:
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
return relative_position_index
return relative_position_index.contiguous()
def gen_relative_log_coords(
win_size: Tuple[int, int],
pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin'
mode='swin',
):
# as per official swin-v2 impl, supporting timm swin-v2-cr coords as well
assert mode in ('swin', 'cr', 'rw')
# as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@ -100,12 +116,22 @@ def gen_relative_log_coords(
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
scale = math.log2(8)
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / math.log2(8)
else:
# FIXME we should support a form of normalization (to -1/1) for this mode?
scale = math.log2(math.e)
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / scale
if mode == 'rw':
# cr w/ window size normalization -> [-1,1] log coords
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # scale to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs())
relative_coords_table /= math.log2(9) # -> [-1, 1]
else:
# mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
1.0 + relative_coords_table.abs())
return relative_coords_table
@ -115,19 +141,29 @@ class RelPosMlp(nn.Module):
window_size,
num_heads=8,
hidden_dim=128,
class_token=False,
prefix_tokens=0,
mode='cr',
pretrained_window_size=(0, 0)
):
super().__init__()
self.window_size = window_size
self.window_area = self.window_size[0] * self.window_size[1]
self.class_token = 1 if class_token else 0
self.prefix_tokens = prefix_tokens
self.num_heads = num_heads
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
self.apply_sigmoid = mode == 'swin'
if mode == 'swin':
self.bias_act = nn.Sigmoid()
self.bias_gain = 16
mlp_bias = (True, False)
elif mode == 'rw':
self.bias_act = nn.Tanh()
self.bias_gain = 4
mlp_bias = True
else:
self.bias_act = nn.Identity()
self.bias_gain = None
mlp_bias = True
mlp_bias = (True, False) if mode == 'swin' else True
self.mlp = Mlp(
2, # x, y
hidden_features=hidden_dim,
@ -155,10 +191,11 @@ class RelPosMlp(nn.Module):
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
if self.apply_sigmoid:
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
if self.class_token:
relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0])
relative_position_bias = self.bias_act(relative_position_bias)
if self.bias_gain is not None:
relative_position_bias = self.bias_gain * relative_position_bias
if self.prefix_tokens:
relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
@ -167,18 +204,18 @@ class RelPosMlp(nn.Module):
class RelPosBias(nn.Module):
def __init__(self, window_size, num_heads, class_token=False):
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.class_token = 1 if class_token else 0
self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,)
self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=self.class_token),
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
persistent=False,
)
@ -306,11 +343,32 @@ class VisionTransformerRelPos(nn.Module):
"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6,
class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='avg',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=1e-6,
class_token=False,
fc_norm=False,
rel_pos_type='mlp',
rel_pos_dim=None,
shared_rel_pos=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='skip',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=RelPosBlock
):
"""
Args:
img_size (int, tuple): input image size
@ -345,19 +403,22 @@ class VisionTransformerRelPos(nn.Module):
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1 if class_token else 0
self.num_prefix_tokens = 1 if class_token else 0
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
feat_size = self.patch_embed.grid_size
rel_pos_args = dict(window_size=feat_size, class_token=class_token)
rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
if rel_pos_type.startswith('mlp'):
if rel_pos_dim:
rel_pos_args['hidden_dim'] = rel_pos_dim
# FIXME experimenting with different relpos log coord configs
if 'swin' in rel_pos_type:
rel_pos_args['mode'] = 'swin'
elif 'rw' in rel_pos_type:
rel_pos_args['mode'] = 'rw'
rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
else:
rel_pos_cls = partial(RelPosBias, **rel_pos_args)
@ -367,7 +428,7 @@ class VisionTransformerRelPos(nn.Module):
# NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
rel_pos_cls = None
self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None
self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
@ -434,7 +495,7 @@ class VisionTransformerRelPos(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
return x if pre_logits else self.head(x)
@ -502,6 +563,41 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
return model
@register_model
def vit_srelpos_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ shared relative log-coord position, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=False,
rel_pos_dim=384, shared_rel_pos=True, **kwargs)
model = _create_vision_transformer_relpos('vit_srelpos_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_srelpos_medium_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ shared relative log-coord position, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
rel_pos_dim=512, shared_rel_pos=True, **kwargs)
model = _create_vision_transformer_relpos(
'vit_srelpos_medium_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_medium_patch16_cls_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-M/16) w/ relative log-coord position, class token present
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
rel_pos_dim=256, class_token=True, global_pool='token', **kwargs)
model = _create_vision_transformer_relpos(
'vit_relpos_medium_patch16_cls_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
@ -514,14 +610,14 @@ def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
@register_model
def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs):
def vit_relpos_base_patch16_clsgap_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
Leaving here for comparisons w/ a future re-train as it performs quite well.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_clsgap_224', pretrained=pretrained, **model_kwargs)
return model
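# A miniature of `gen_relative_position_index` above for a 3x3 window: every query/key
# pair is mapped to the index of its unique 2D relative offset, so a learned table of
# (2*3-1) * (2*3-1) = 25 entries covers all 81 token pairs.
import torch

coords = torch.stack(torch.meshgrid([torch.arange(3), torch.arange(3)])).flatten(1)  # 2, 9
rel = (coords[:, :, None] - coords[:, None, :]).permute(1, 2, 0)                     # 9, 9, 2
_, rel_index = torch.unique(rel.reshape(-1, 2), return_inverse=True, dim=0)
rel_index = rel_index.reshape(9, 9)
print(rel_index.shape, int(rel_index.max()) + 1)  # torch.Size([9, 9]) 25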
