Merge pull request #1327 from rwightman/edgenext_csp_and_more

EdgeNeXt, additional DarkNets, and more
pull/1345/head
Ross Wightman 2 years ago committed by GitHub
commit 1ccce50d48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,24 +6,23 @@ An inference and train step benchmark script for timm models.
Hacked together by Ross Wightman (https://github.com/rwightman) Hacked together by Ross Wightman (https://github.com/rwightman)
""" """
import argparse import argparse
import os
import csv import csv
import json import json
import time
import logging import logging
import torch import time
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict from collections import OrderedDict
from contextlib import suppress from contextlib import suppress
from functools import partial from functools import partial
import torch
import torch.nn as nn
import torch.nn.parallel
from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2 from timm.optim import create_optimizer_v2
from timm.data import resolve_data_config
from timm.utils import setup_default_logging, set_jit_fuser from timm.utils import setup_default_logging, set_jit_fuser
has_apex = False has_apex = False
try: try:
from apex import amp from apex import amp
@ -71,6 +70,8 @@ parser.add_argument('--bench', default='both', type=str,
help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'") help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
parser.add_argument('--detail', action='store_true', default=False, parser.add_argument('--detail', action='store_true', default=False,
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False') help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--no-retry', action='store_true', default=False,
help='Do not decay batch size and retry on error.')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)') help='Output csv file for validation results (summary)')
parser.add_argument('--num-warm-iter', default=10, type=int, parser.add_argument('--num-warm-iter', default=10, type=int,
@ -169,10 +170,9 @@ def resolve_precision(precision: str):
def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False): def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
macs, _ = get_model_profile( _, macs, _ = get_model_profile(
model=model, model=model,
input_res=(batch_size,) + input_size, # input shape or input to the input_constructor input_shape=(batch_size,) + input_size, # input shape/resolution
input_constructor=None, # if specified, a constructor taking input_res is used as input to the model
print_profile=detailed, # prints the model graph with the measured profile attached to each module print_profile=detailed, # prints the model graph with the measured profile attached to each module
detailed=detailed, # print the detailed profile detailed=detailed, # print the detailed profile
warm_up=10, # the number of warm-ups before measuring the time of each module warm_up=10, # the number of warm-ups before measuring the time of each module
@ -197,8 +197,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
class BenchmarkRunner: class BenchmarkRunner:
def __init__( def __init__(
self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32', self,
fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): model_name,
detail=False,
device='cuda',
torchscript=False,
aot_autograd=False,
precision='float32',
fuser='',
num_warm_iter=10,
num_bench_iter=50,
use_train_size=False,
**kwargs
):
self.model_name = model_name self.model_name = model_name
self.detail = detail self.detail = detail
self.device = device self.device = device
@ -225,11 +236,12 @@ class BenchmarkRunner:
self.num_classes = self.model.num_classes self.num_classes = self.model.num_classes
self.param_count = count_params(self.model) self.param_count = count_params(self.model)
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
self.scripted = False self.scripted = False
if torchscript: if torchscript:
self.model = torch.jit.script(self.model) self.model = torch.jit.script(self.model)
self.scripted = True self.scripted = True
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
self.input_size = data_config['input_size'] self.input_size = data_config['input_size']
self.batch_size = kwargs.pop('batch_size', 256) self.batch_size = kwargs.pop('batch_size', 256)
@ -255,7 +267,13 @@ class BenchmarkRunner:
class InferenceBenchmarkRunner(BenchmarkRunner): class InferenceBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): def __init__(
self,
model_name,
device='cuda',
torchscript=False,
**kwargs
):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.eval() self.model.eval()
@ -324,7 +342,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
class TrainBenchmarkRunner(BenchmarkRunner): class TrainBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): def __init__(
self,
model_name,
device='cuda',
torchscript=False,
**kwargs
):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.train() self.model.train()
@ -491,7 +515,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
return max(0, int(out_batch_size)) return max(0, int(out_batch_size))
def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
batch_size = initial_batch_size batch_size = initial_batch_size
results = dict() results = dict()
error_str = 'Unknown' error_str = 'Unknown'
@ -506,8 +530,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
if 'channels_last' in error_str: if 'channels_last' in error_str:
_logger.error(f'{model_name} not supported in channels_last, skipping.') _logger.error(f'{model_name} not supported in channels_last, skipping.')
break break
_logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.') _logger.error(f'"{error_str}" while running benchmark.')
if no_batch_size_retry:
break
batch_size = decay_batch_exp(batch_size) batch_size = decay_batch_exp(batch_size)
_logger.warning(f'Reducing batch size to {batch_size} for retry.')
results['error'] = error_str results['error'] = error_str
return results return results
@ -549,7 +576,13 @@ def benchmark(args):
model_results = OrderedDict(model=model) model_results = OrderedDict(model=model)
for prefix, bench_fn in zip(prefixes, bench_fns): for prefix, bench_fn in zip(prefixes, bench_fns):
run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs) run_results = _try_run(
model,
bench_fn,
bench_kwargs=bench_kwargs,
initial_batch_size=batch_size,
no_batch_size_retry=args.no_retry,
)
if prefix and 'error' not in run_results: if prefix and 'error' not in run_results:
run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()} run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
model_results.update(run_results) model_results.update(run_results)

@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset from .dataset_factory import create_dataset
from .loader import create_loader from .loader import create_loader
from .mixup import Mixup, FastCollateMixup from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser from .parsers import create_parser,\
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet from .real_labels import RealLabelsImagenet
from .transforms import * from .transforms import *
from .transforms_factory import create_transform from .transforms_factory import create_transform

@ -64,11 +64,15 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v
new_config['std'] = default_cfg['std'] new_config['std'] = default_cfg['std']
# resolve default crop percentage # resolve default crop percentage
new_config['crop_pct'] = DEFAULT_CROP_PCT crop_pct = DEFAULT_CROP_PCT
if 'crop_pct' in args and args['crop_pct'] is not None: if 'crop_pct' in args and args['crop_pct'] is not None:
new_config['crop_pct'] = args['crop_pct'] crop_pct = args['crop_pct']
else:
if use_test_size and 'test_crop_pct' in default_cfg:
crop_pct = default_cfg['test_crop_pct']
elif 'crop_pct' in default_cfg: elif 'crop_pct' in default_cfg:
new_config['crop_pct'] = default_cfg['crop_pct'] crop_pct = default_cfg['crop_pct']
new_config['crop_pct'] = crop_pct
if verbose: if verbose:
_logger.info('Data processing configuration for current model + dataset:') _logger.info('Data processing configuration for current model + dataset:')

@ -26,8 +26,8 @@ _TORCH_BASIC_DS = dict(
kmnist=KMNIST, kmnist=KMNIST,
fashion_mnist=FashionMNIST, fashion_mnist=FashionMNIST,
) )
_TRAIN_SYNONYM = {'train', 'training'} _TRAIN_SYNONYM = dict(train=None, training=None)
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'} _EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None)
def _search_split(root, split): def _search_split(root, split):

@ -1 +1,2 @@
from .parser_factory import create_parser from .parser_factory import create_parser
from .img_extensions import *

@ -1 +0,0 @@
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

@ -0,0 +1,50 @@
from copy import deepcopy
__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use
_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync
def _set_extensions(extensions):
global IMG_EXTENSIONS
global _IMG_EXTENSIONS_SET
dedupe = set() # NOTE de-duping tuple while keeping original order
IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x))
_IMG_EXTENSIONS_SET = set(extensions)
def _valid_extension(x: str):
return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.')
def is_img_extension(ext):
return ext in _IMG_EXTENSIONS_SET
def get_img_extensions(as_set=False):
return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)
def set_img_extensions(extensions):
assert len(extensions)
for x in extensions:
assert _valid_extension(x)
_set_extensions(extensions)
def add_img_extensions(ext):
if not isinstance(ext, (list, tuple, set)):
ext = (ext,)
for x in ext:
assert _valid_extension(x)
extensions = IMG_EXTENSIONS + tuple(ext)
_set_extensions(extensions)
def del_img_extensions(ext):
if not isinstance(ext, (list, tuple, set)):
ext = (ext,)
extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
_set_extensions(extensions)

@ -1,7 +1,6 @@
import os import os
from .parser_image_folder import ParserImageFolder from .parser_image_folder import ParserImageFolder
from .parser_image_tar import ParserImageTar
from .parser_image_in_tar import ParserImageInTar from .parser_image_in_tar import ParserImageInTar

@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default.
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import os import os
from typing import Dict, List, Optional, Set, Tuple, Union
from timm.utils.misc import natural_key from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map from .class_map import load_class_map
from .constants import IMG_EXTENSIONS from .img_extensions import get_img_extensions
from .parser import Parser
def find_images_and_targets(
folder: str,
types: Optional[Union[List, Tuple, Set]] = None,
class_to_idx: Optional[Dict] = None,
leaf_name_only: bool = True,
sort: bool = True
):
""" Walk folder recursively to discover images and map them to classes by folder names.
Args:
folder: root of folder to recrusively search
types: types (file extensions) to search for in path
class_to_idx: specify mapping for class (folder name) to class index if set
leaf_name_only: use only leaf-name of folder walk for class names
sort: re-sort found images by name (for consistent ordering)
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): Returns:
A list of image and target tuples, class_to_idx mapping
"""
types = get_img_extensions(as_set=True) if not types else set(types)
labels = [] labels = []
filenames = [] filenames = []
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
@ -51,7 +71,8 @@ class ParserImageFolder(Parser):
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
if len(self.samples) == 0: if len(self.samples) == 0:
raise RuntimeError( raise RuntimeError(
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') f'Found 0 images in subfolders of {root}. '
f'Supported image extensions are {", ".join(get_img_extensions())}')
def __getitem__(self, index): def __getitem__(self, index):
path, target = self.samples[index] path, target = self.samples[index]

@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure.
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import logging
import os import os
import tarfile
import pickle import pickle
import logging import tarfile
import numpy as np
from glob import glob from glob import glob
from typing import List, Dict from typing import List, Tuple, Dict, Set, Optional, Union
import numpy as np
from timm.utils.misc import natural_key from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map from .class_map import load_class_map
from .constants import IMG_EXTENSIONS from .img_extensions import get_img_extensions
from .parser import Parser
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle' CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@ -39,7 +39,7 @@ class TarState:
self.tf = None self.tf = None
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS): def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
sample_count = 0 sample_count = 0
for i, ti in enumerate(tf): for i, ti in enumerate(tf):
if not ti.isfile(): if not ti.isfile():
@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE
return sample_count return sample_count
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True): def extract_tarinfos(
root,
class_name_to_idx: Optional[Dict] = None,
cache_tarinfo: Optional[bool] = None,
extensions: Optional[Union[List, Tuple, Set]] = None,
sort: bool = True
):
extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
root_is_tar = False root_is_tar = False
if os.path.isfile(root): if os.path.isfile(root):
assert os.path.splitext(root)[-1].lower() == '.tar' assert os.path.splitext(root)[-1].lower() == '.tar'
@ -176,8 +183,8 @@ class ParserImageInTar(Parser):
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos( self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
self.root, self.root,
class_name_to_idx=class_name_to_idx, class_name_to_idx=class_name_to_idx,
cache_tarinfo=cache_tarinfo, cache_tarinfo=cache_tarinfo
extensions=IMG_EXTENSIONS) )
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()} self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
if len(tarfiles) == 1 and tarfiles[0][0] is None: if len(tarfiles) == 1 and tarfiles[0][0] is None:
self.root_is_tar = True self.root_is_tar = True

@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman
import os import os
import tarfile import tarfile
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from timm.utils.misc import natural_key from timm.utils.misc import natural_key
from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser
def extract_tarinfo(tarfile, class_to_idx=None, sort=True): def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
extensions = get_img_extensions(as_set=True)
files = [] files = []
labels = [] labels = []
for ti in tarfile.getmembers(): for ti in tarfile.getmembers():
@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
dirname, basename = os.path.split(ti.path) dirname, basename = os.path.split(ti.path)
label = os.path.basename(dirname) label = os.path.basename(dirname)
ext = os.path.splitext(basename)[1] ext = os.path.splitext(basename)[1]
if ext.lower() in IMG_EXTENSIONS: if ext.lower() in extensions:
files.append(ti) files.append(ti)
labels.append(label) labels.append(label)
if class_to_idx is None: if class_to_idx is None:

@ -12,6 +12,7 @@ from .deit import *
from .densenet import * from .densenet import *
from .dla import * from .dla import *
from .dpn import * from .dpn import *
from .edgenext import *
from .efficientnet import * from .efficientnet import *
from .ghostnet import * from .ghostnet import *
from .gluon_resnet import * from .gluon_resnet import *

@ -17,9 +17,8 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
from .registry import register_model from .registry import register_model
@ -44,6 +43,7 @@ default_cfgs = dict(
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
convnext_nano_hnf=_cfg(url=''), convnext_nano_hnf=_cfg(url=''),
convnext_nano_ols=_cfg(url=''),
convnext_tiny_hnf=_cfg( convnext_tiny_hnf=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
crop_pct=0.95), crop_pct=0.95),
@ -88,35 +88,6 @@ default_cfgs = dict(
) )
def _is_contiguous(tensor: torch.Tensor) -> bool:
# jit is oh so lovely :/
# if torch.jit.is_tracing():
# return True
if torch.jit.is_scripting():
return tensor.is_contiguous()
else:
return tensor.is_contiguous(memory_format=torch.contiguous_format)
@register_notrace_module
class LayerNorm2d(nn.LayerNorm):
r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__(normalized_shape, eps=eps)
def forward(self, x) -> torch.Tensor:
if _is_contiguous(x):
return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
else:
s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
x = (x - u) * torch.rsqrt(s + self.eps)
x = x * self.weight[:, None, None] + self.bias[:, None, None]
return x
class ConvNeXtBlock(nn.Module): class ConvNeXtBlock(nn.Module):
""" ConvNeXt Block """ ConvNeXt Block
There are two equivalent implementations: There are two equivalent implementations:
@ -133,16 +104,30 @@ class ConvNeXtBlock(nn.Module):
ls_init_value (float): Init value for Layer Scale. Default: 1e-6. ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
""" """
def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None): def __init__(
self,
dim,
dim_out=None,
stride=1,
mlp_ratio=4,
conv_mlp=False,
conv_bias=True,
ls_init_value=1e-6,
norm_layer=None,
act_layer=nn.GELU,
drop_path=0.,
):
super().__init__() super().__init__()
dim_out = dim_out or dim
if not norm_layer: if not norm_layer:
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
mlp_layer = ConvMlp if conv_mlp else Mlp mlp_layer = ConvMlp if conv_mlp else Mlp
self.use_conv_mlp = conv_mlp self.use_conv_mlp = conv_mlp
self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.norm = norm_layer(dim) self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU) self.norm = norm_layer(dim_out)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x): def forward(self, x):
@ -158,6 +143,7 @@ class ConvNeXtBlock(nn.Module):
x = x.permute(0, 3, 1, 2) x = x.permute(0, 3, 1, 2)
if self.gamma is not None: if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1)) x = x.mul(self.gamma.reshape(1, -1, 1, 1))
x = self.drop_path(x) + shortcut x = self.drop_path(x) + shortcut
return x return x
@ -165,25 +151,44 @@ class ConvNeXtBlock(nn.Module):
class ConvNeXtStage(nn.Module): class ConvNeXtStage(nn.Module):
def __init__( def __init__(
self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, self,
norm_layer=None, cl_norm_layer=None, cross_stage=False): in_chs,
out_chs,
stride=2,
depth=2,
drop_path_rates=None,
ls_init_value=1.0,
conv_mlp=False,
conv_bias=True,
norm_layer=None,
norm_layer_cl=None
):
super().__init__() super().__init__()
self.grad_checkpointing = False self.grad_checkpointing = False
if in_chs != out_chs or stride > 1: if in_chs != out_chs or stride > 1:
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
norm_layer(in_chs), norm_layer(in_chs),
nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride), nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
) )
in_chs = out_chs
else: else:
self.downsample = nn.Identity() self.downsample = nn.Identity()
dp_rates = dp_rates or [0.] * depth drop_path_rates = drop_path_rates or [0.] * depth
self.blocks = nn.Sequential(*[ConvNeXtBlock( stage_blocks = []
dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, for i in range(depth):
norm_layer=norm_layer if conv_mlp else cl_norm_layer) stage_blocks.append(ConvNeXtBlock(
for j in range(depth)] dim=in_chs,
) dim_out=out_chs,
drop_path=drop_path_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
norm_layer=norm_layer if conv_mlp else norm_layer_cl
))
in_chs = out_chs
self.blocks = nn.Sequential(*stage_blocks)
def forward(self, x): def forward(self, x):
x = self.downsample(x) x = self.downsample(x)
@ -210,41 +215,56 @@ class ConvNeXt(nn.Module):
""" """
def __init__( def __init__(
self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, self,
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch', in_chans=3,
head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., num_classes=1000,
global_pool='avg',
output_stride=32,
depths=(3, 3, 9, 3),
dims=(96, 192, 384, 768),
ls_init_value=1e-6,
stem_type='patch',
stem_kernel_size=4,
stem_stride=4,
head_init_scale=1.,
head_norm_first=False,
conv_mlp=False,
conv_bias=True,
norm_layer=None,
drop_rate=0.,
drop_path_rate=0.,
): ):
super().__init__() super().__init__()
assert output_stride == 32 assert output_stride == 32
if norm_layer is None: if norm_layer is None:
norm_layer = partial(LayerNorm2d, eps=1e-6) norm_layer = partial(LayerNorm2d, eps=1e-6)
cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
else: else:
assert conv_mlp,\ assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
cl_norm_layer = norm_layer norm_layer_cl = norm_layer
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.feature_info = [] self.feature_info = []
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 assert stem_type in ('patch', 'overlap')
if stem_type == 'patch': if stem_type == 'patch':
assert stem_kernel_size == stem_stride
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential( self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias),
norm_layer(dims[0]) norm_layer(dims[0])
) )
curr_stride = patch_size
prev_chs = dims[0]
else: else:
self.stem = nn.Sequential( self.stem = nn.Sequential(
nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1), nn.Conv2d(
norm_layer(32), in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride,
nn.GELU(), padding=stem_kernel_size // 2, bias=conv_bias),
nn.Conv2d(32, 64, kernel_size=3, padding=1), norm_layer(dims[0]),
) )
curr_stride = 2 prev_chs = dims[0]
prev_chs = 64 curr_stride = stem_stride
self.stages = nn.Sequential() self.stages = nn.Sequential()
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
@ -256,16 +276,23 @@ class ConvNeXt(nn.Module):
curr_stride *= stride curr_stride *= stride
out_chs = dims[i] out_chs = dims[i]
stages.append(ConvNeXtStage( stages.append(ConvNeXtStage(
prev_chs, out_chs, stride=stride, prev_chs,
depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, out_chs,
norm_layer=norm_layer, cl_norm_layer=cl_norm_layer) stride=stride,
) depth=depths[i],
drop_path_rates=dp_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl
))
prev_chs = out_chs prev_chs = out_chs
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages) self.stages = nn.Sequential(*stages)
self.num_features = prev_chs self.num_features = prev_chs
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
@ -327,10 +354,11 @@ class ConvNeXt(nn.Module):
def _init_weights(module, name=None, head_init_scale=1.0): def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02) trunc_normal_(module.weight, std=.02)
nn.init.constant_(module.bias, 0) if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear): elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02) trunc_normal_(module.weight, std=.02)
nn.init.constant_(module.bias, 0) nn.init.zeros_(module.bias)
if name and 'head.' in name: if name and 'head.' in name:
module.weight.data.mul_(head_init_scale) module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale) module.bias.data.mul_(head_init_scale)
@ -371,14 +399,25 @@ def _create_convnext(variant, pretrained=False, **kwargs):
@register_model @register_model
def convnext_nano_hnf(pretrained=False, **kwargs): def convnext_nano_hnf(pretrained=False, **kwargs):
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args) model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
return model return model
@register_model
def convnext_nano_ols(pretrained=False, **kwargs):
model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True,
conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs)
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
return model
@register_model @register_model
def convnext_tiny_hnf(pretrained=False, **kwargs): def convnext_tiny_hnf(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) model_args = dict(
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
return model return model
@ -386,7 +425,7 @@ def convnext_tiny_hnf(pretrained=False, **kwargs):
@register_model @register_model
def convnext_tiny_hnfd(pretrained=False, **kwargs): def convnext_tiny_hnfd(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, stem_type='dual', **kwargs) depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
return model return model

File diff suppressed because it is too large Load Diff

@ -1,7 +1,10 @@
""" DeiT - Data-efficient Image Transformers """ DeiT - Data-efficient Image Transformers
DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
Modifications copyright 2021, Ross Wightman Modifications copyright 2021, Ross Wightman
""" """
@ -53,6 +56,46 @@ default_cfgs = {
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0, input_size=(3, 384, 384), crop_pct=1.0,
classifier=('head', 'head_dist')), classifier=('head', 'head_dist')),
'deit3_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
'deit3_small_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
'deit3_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_large_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
'deit3_large_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_huge_patch14_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
'deit3_small_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
crop_pct=1.0),
'deit3_small_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_base_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
crop_pct=1.0),
'deit3_base_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_large_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
crop_pct=1.0),
'deit3_large_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_huge_patch14_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
crop_pct=1.0),
} }
@ -68,9 +111,10 @@ class VisionTransformerDistilled(VisionTransformer):
super().__init__(*args, **kwargs, weight_init='skip') super().__init__(*args, **kwargs, weight_init='skip')
assert self.global_pool in ('token',) assert self.global_pool in ('token',)
self.num_tokens = 2 self.num_prefix_tokens = 2
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim)) self.pos_embed = nn.Parameter(
torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False # must set this True to train w/ distillation token self.distilled_training = False # must set this True to train w/ distillation token
@ -220,3 +264,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
model = _create_deit( model = _create_deit(
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model return model
@register_model
def deit3_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_small_patch16_384(pretrained=False, **kwargs):
""" DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_384(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_384(pretrained=False, **kwargs):
""" DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_huge_patch14_224(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model

@ -0,0 +1,557 @@
""" EdgeNeXt
Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications`
- https://arxiv.org/abs/2206.10589
Original code and weights from https://github.com/mmaaz60/EdgeNeXt
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
import torch
from collections import OrderedDict
from functools import partial
from typing import Tuple
from torch import nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .registry import register_model
__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
edgenext_xx_small=_cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth"),
edgenext_x_small=_cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth"),
# edgenext_small=_cfg(
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"),
edgenext_small=_cfg( # USI weights
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
edgenext_small_rw=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
)
@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
def __init__(self, hidden_dim=32, dim=768, temperature=10000):
super().__init__()
self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
self.scale = 2 * math.pi
self.temperature = temperature
self.hidden_dim = hidden_dim
self.dim = dim
def forward(self, shape: Tuple[int, int, int]):
inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool)
y_embed = inv_mask.cumsum(1, dtype=torch.float32)
x_embed = inv_mask.cumsum(2, dtype=torch.float32)
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(),
pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(),
pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
pos = self.token_projection(pos)
return pos
class ConvBlock(nn.Module):
def __init__(
self,
dim,
dim_out=None,
kernel_size=7,
stride=1,
conv_bias=True,
expand_ratio=4,
ls_init_value=1e-6,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU, drop_path=0.,
):
super().__init__()
dim_out = dim_out or dim
self.shortcut_after_dw = stride > 1 or dim != dim_out
self.conv_dw = create_conv2d(
dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias)
self.norm = norm_layer(dim_out)
self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv_dw(x)
if self.shortcut_after_dw:
shortcut = x
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.mlp(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = shortcut + self.drop_path(x)
return x
class CrossCovarianceAttn(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.
):
super().__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
q, k, v = qkv.unbind(0)
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@torch.jit.ignore
def no_weight_decay(self):
return {'temperature'}
class SplitTransposeBlock(nn.Module):
def __init__(
self,
dim,
num_scales=1,
num_heads=8,
expand_ratio=4,
use_pos_emb=True,
conv_bias=True,
qkv_bias=True,
ls_init_value=1e-6,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
drop_path=0.,
attn_drop=0.,
proj_drop=0.
):
super().__init__()
width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales)))
self.width = width
self.num_scales = max(1, num_scales - 1)
convs = []
for i in range(self.num_scales):
convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias))
self.convs = nn.ModuleList(convs)
self.pos_embd = None
if use_pos_emb:
self.pos_embd = PositionalEncodingFourier(dim=dim)
self.norm_xca = norm_layer(dim)
self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
self.xca = CrossCovarianceAttn(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
self.norm = norm_layer(dim, eps=1e-6)
self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
# scales code re-written for torchscript as per my res2net fixes -rw
spx = torch.split(x, self.width, 1)
spo = []
sp = spx[0]
for i, conv in enumerate(self.convs):
if i > 0:
sp = sp + spx[i]
sp = conv(sp)
spo.append(sp)
spo.append(spx[-1])
x = torch.cat(spo, 1)
# XCA
B, C, H, W = x.shape
x = x.reshape(B, C, H * W).permute(0, 2, 1)
if self.pos_embd is not None:
pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
x = x + pos_encoding
x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x)))
x = x.reshape(B, H, W, C)
# Inverted Bottleneck
x = self.norm(x)
x = self.mlp(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = shortcut + self.drop_path(x)
return x
class EdgeNeXtStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
stride=2,
depth=2,
num_global_blocks=1,
num_heads=4,
scales=2,
kernel_size=7,
expand_ratio=4,
use_pos_emb=False,
downsample_block=False,
conv_bias=True,
ls_init_value=1.0,
drop_path_rates=None,
norm_layer=LayerNorm2d,
norm_layer_cl=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU
):
super().__init__()
self.grad_checkpointing = False
if downsample_block or stride == 1:
self.downsample = nn.Identity()
else:
self.downsample = nn.Sequential(
norm_layer(in_chs),
nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias)
)
in_chs = out_chs
stage_blocks = []
for i in range(depth):
if i < depth - num_global_blocks:
stage_blocks.append(
ConvBlock(
dim=in_chs,
dim_out=out_chs,
stride=stride if downsample_block and i == 0 else 1,
conv_bias=conv_bias,
kernel_size=kernel_size,
expand_ratio=expand_ratio,
ls_init_value=ls_init_value,
drop_path=drop_path_rates[i],
norm_layer=norm_layer_cl,
act_layer=act_layer,
)
)
else:
stage_blocks.append(
SplitTransposeBlock(
dim=in_chs,
num_scales=scales,
num_heads=num_heads,
expand_ratio=expand_ratio,
use_pos_emb=use_pos_emb,
conv_bias=conv_bias,
ls_init_value=ls_init_value,
drop_path=drop_path_rates[i],
norm_layer=norm_layer_cl,
act_layer=act_layer,
)
)
in_chs = out_chs
self.blocks = nn.Sequential(*stage_blocks)
def forward(self, x):
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class EdgeNeXt(nn.Module):
def __init__(
self,
in_chans=3,
num_classes=1000,
global_pool='avg',
dims=(24, 48, 88, 168),
depths=(3, 3, 9, 3),
global_block_counts=(0, 1, 1, 1),
kernel_sizes=(3, 5, 7, 9),
heads=(8, 8, 8, 8),
d2_scales=(2, 2, 3, 4),
use_pos_emb=(False, True, False, False),
ls_init_value=1e-6,
head_init_scale=1.,
expand_ratio=4,
downsample_block=False,
conv_bias=True,
stem_type='patch',
head_norm_first=False,
act_layer=nn.GELU,
drop_path_rate=0.,
drop_rate=0.,
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.drop_rate = drop_rate
norm_layer = partial(LayerNorm2d, eps=1e-6)
norm_layer_cl = partial(nn.LayerNorm, eps=1e-6)
self.feature_info = []
assert stem_type in ('patch', 'overlap')
if stem_type == 'patch':
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias),
norm_layer(dims[0]),
)
else:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias),
norm_layer(dims[0]),
)
curr_stride = 4
stages = []
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
in_chs = dims[0]
for i in range(4):
stride = 2 if curr_stride == 2 or i > 0 else 1
# FIXME support dilation / output_stride
curr_stride *= stride
stages.append(EdgeNeXtStage(
in_chs=in_chs,
out_chs=dims[i],
stride=stride,
depth=depths[i],
num_global_blocks=global_block_counts[i],
num_heads=heads[i],
drop_path_rates=dp_rates[i],
scales=d2_scales[i],
expand_ratio=expand_ratio,
kernel_size=kernel_sizes[i],
use_pos_emb=use_pos_emb[i],
ls_init_value=ls_init_value,
downsample_block=downsample_block,
conv_bias=conv_bias,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
act_layer=act_layer,
))
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
in_chs = dims[i]
self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_features = dims[-1]
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.downsample', (0,)), # blocks
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm_pre', (99999,))
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
x = self.head.global_pool(x)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d):
trunc_normal_tf_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
trunc_normal_tf_(module.weight, std=.02)
nn.init.zeros_(module.bias)
if name and 'head.' in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
def checkpoint_filter_fn(state_dict, model):
""" Remap FB checkpoints -> timm """
if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
return state_dict # non-FB checkpoint
# models were released as train checkpoints... :/
if 'model_ema' in state_dict:
state_dict = state_dict['model_ema']
elif 'model' in state_dict:
state_dict = state_dict['model']
elif 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
out_dict = {}
import re
for k, v in state_dict.items():
k = k.replace('downsample_layers.0.', 'stem.')
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
k = k.replace('dwconv', 'conv_dw')
k = k.replace('pwconv', 'mlp.fc')
k = k.replace('head.', 'head.fc.')
if k.startswith('norm.'):
k = k.replace('norm', 'head.norm')
if v.ndim == 2 and 'head' not in k:
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)
out_dict[k] = v
return out_dict
def _create_edgenext(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
EdgeNeXt, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
**kwargs)
return model
@register_model
def edgenext_xx_small(pretrained=False, **kwargs):
# 1.33M & 260.58M @ 256 resolution
# 71.23% Top-1 accuracy
# No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=51.66 versus 47.67 for MobileViT_XXS
# For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS
model_kwargs = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4), **kwargs)
return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **model_kwargs)
@register_model
def edgenext_x_small(pretrained=False, **kwargs):
# 2.34M & 538.0M @ 256 resolution
# 75.00% Top-1 accuracy
# No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=31.61 versus 28.49 for MobileViT_XS
# For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS
model_kwargs = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4), **kwargs)
return _create_edgenext('edgenext_x_small', pretrained=pretrained, **model_kwargs)
@register_model
def edgenext_small(pretrained=False, **kwargs):
# 5.59M & 1260.59M @ 256 resolution
# 79.43% Top-1 accuracy
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
model_kwargs = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304), **kwargs)
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
@register_model
def edgenext_small_rw(pretrained=False, **kwargs):
# 5.59M & 1260.59M @ 256 resolution
# 79.43% Top-1 accuracy
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
model_kwargs = dict(
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **model_kwargs)

@ -25,7 +25,7 @@ from .linear import Linear
from .mixed_conv2d import MixedConv2d from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d from .norm import GroupNorm, GroupNorm1, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed from .patch_embed import PatchEmbed
@ -39,4 +39,4 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_

@ -2,6 +2,7 @@
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import functools
from torch import nn as nn from torch import nn as nn
from .create_conv2d import create_conv2d from .create_conv2d import create_conv2d
@ -40,12 +41,26 @@ class ConvNormAct(nn.Module):
ConvBnAct = ConvNormAct ConvBnAct = ConvNormAct
def create_aa(aa_layer, channels, stride=2, enable=True):
if not aa_layer or not enable:
return nn.Identity()
if isinstance(aa_layer, functools.partial):
if issubclass(aa_layer.func, nn.AvgPool2d):
return aa_layer()
else:
return aa_layer(channels)
elif issubclass(aa_layer, nn.AvgPool2d):
return aa_layer(stride)
else:
return aa_layer(channels=channels, stride=stride)
class ConvNormActAa(nn.Module): class ConvNormActAa(nn.Module):
def __init__( def __init__(
self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
super(ConvNormActAa, self).__init__() super(ConvNormActAa, self).__init__()
use_aa = aa_layer is not None use_aa = aa_layer is not None and stride == 2
self.conv = create_conv2d( self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
@ -56,7 +71,7 @@ class ConvNormActAa(nn.Module):
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn` # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity() self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
@property @property
def in_channels(self): def in_channels(self):

@ -22,7 +22,7 @@ def get_attn(attn_type):
if isinstance(attn_type, torch.nn.Module): if isinstance(attn_type, torch.nn.Module):
return attn_type return attn_type
module_cls = None module_cls = None
if attn_type is not None: if attn_type:
if isinstance(attn_type, str): if isinstance(attn_type, str):
attn_type = attn_type.lower() attn_type = attn_type.lower()
# Lightweight attention modules (channel and/or coarse spatial). # Lightweight attention modules (channel and/or coarse spatial).

@ -14,11 +14,59 @@ class GroupNorm(nn.GroupNorm):
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, *]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class LayerNorm2d(nn.LayerNorm): class LayerNorm2d(nn.LayerNorm):
""" LayerNorm for channels of '2D' spatial BCHW tensors """ """ LayerNorm for channels of '2D' spatial NCHW tensors """
def __init__(self, num_channels): def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels) super().__init__(num_channels, eps=eps, elementwise_affine=affine)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.layer_norm( return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
def _is_contiguous(tensor: torch.Tensor) -> bool:
# jit is oh so lovely :/
# if torch.jit.is_tracing():
# return True
if torch.jit.is_scripting():
return tensor.is_contiguous()
else:
return tensor.is_contiguous(memory_format=torch.contiguous_format)
@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
x = (x - u) * torch.rsqrt(s + eps)
x = x * weight[:, None, None] + bias[:, None, None]
return x
class LayerNormExp2d(nn.LayerNorm):
""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
Experimental implementation w/ manual norm for tensors non-contiguous tensors.
This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
layout. However, benefits are not always clear and can perform worse on other GPUs.
"""
def __init__(self, num_channels, eps=1e-6):
super().__init__(num_channels, eps=eps)
def forward(self, x) -> torch.Tensor:
if _is_contiguous(x):
x = F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
else:
x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
return x

@ -36,7 +36,7 @@ class TestTimePoolHead(nn.Module):
return x.view(x.size(0), -1) return x.view(x.size(0), -1)
def apply_test_time_pool(model, config, use_test_size=True): def apply_test_time_pool(model, config, use_test_size=False):
test_time_pool = False test_time_pool = False
if not hasattr(model, 'default_cfg') or not model.default_cfg: if not hasattr(model, 'default_cfg') or not model.default_cfg:
return model, False return model, False

@ -49,6 +49,11 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
with values outside :math:`[a, b]` redrawn until they are within with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`. best when :math:`a \leq \text{mean} \leq b`.
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
applied while sampling the normal with mean/std applied, therefore a, b args
should be adjusted to match the range of mean, std args.
Args: Args:
tensor: an n-dimensional `torch.Tensor` tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution mean: the mean of the normal distribution
@ -62,6 +67,35 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
return _no_grad_trunc_normal_(tensor, mean, std, a, b) return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
and the result is subsquently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
_no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
with torch.no_grad():
tensor.mul_(std).add_(mean)
return tensor
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == 'fan_in': if mode == 'fan_in':
@ -75,7 +109,7 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
if distribution == "truncated_normal": if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2) # constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal": elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance)) tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform": elif distribution == "uniform":

@ -1,7 +1,8 @@
""" MobileViT """ MobileViT
Paper: Paper:
`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178
V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680
MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below) MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below)
License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source) License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source)
@ -13,7 +14,7 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022
# Copyright (C) 2020 Apple Inc. All Rights Reserved. # Copyright (C) 2020 Apple Inc. All Rights Reserved.
# #
import math import math
from typing import Union, Callable, Dict, Tuple, Optional from typing import Union, Callable, Dict, Tuple, Optional, Sequence
import torch import torch
from torch import nn from torch import nn
@ -21,7 +22,7 @@ import torch.nn.functional as F
from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
from .fx_features import register_notrace_module from .fx_features import register_notrace_module
from .layers import to_2tuple, make_divisible from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath
from .vision_transformer import Block as TransformerBlock from .vision_transformer import Block as TransformerBlock
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .registry import register_model from .registry import register_model
@ -48,6 +49,48 @@ default_cfgs = {
'mobilevit_s': _cfg( 'mobilevit_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'),
'semobilevit_s': _cfg(), 'semobilevit_s': _cfg(),
'mobilevitv2_050': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth',
crop_pct=0.888),
'mobilevitv2_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth',
crop_pct=0.888),
'mobilevitv2_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth',
crop_pct=0.888),
'mobilevitv2_125': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth',
crop_pct=0.888),
'mobilevitv2_150': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth',
crop_pct=0.888),
'mobilevitv2_175': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth',
crop_pct=0.888),
'mobilevitv2_200': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth',
crop_pct=0.888),
'mobilevitv2_150_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth',
crop_pct=0.888),
'mobilevitv2_175_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth',
crop_pct=0.888),
'mobilevitv2_200_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth',
crop_pct=0.888),
'mobilevitv2_150_384_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'mobilevitv2_175_384_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'mobilevitv2_200_384_in22ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
} }
@ -72,6 +115,40 @@ def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4,
) )
def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5):
# inverted residual + mobilevit blocks as per MobileViT network
return (
_inverted_residual_block(d=d, c=c, s=s, br=br),
ByoBlockCfg(
type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1,
block_kwargs=dict(
transformer_depth=transformer_depth,
patch_size=patch_size)
)
)
def _mobilevitv2_cfg(multiplier=1.0):
chs = (64, 128, 256, 384, 512)
if multiplier != 1.0:
chs = tuple([int(c * multiplier) for c in chs])
cfg = ByoModelCfg(
blocks=(
_inverted_residual_block(d=1, c=chs[0], s=1, br=2.0),
_inverted_residual_block(d=2, c=chs[1], s=2, br=2.0),
_mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2),
_mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4),
_mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3),
),
stem_chs=int(32 * multiplier),
stem_type='3x3',
stem_pool='',
downsample='',
act_layer='silu',
)
return cfg
model_cfgs = dict( model_cfgs = dict(
mobilevit_xxs=ByoModelCfg( mobilevit_xxs=ByoModelCfg(
blocks=( blocks=(
@ -137,11 +214,19 @@ model_cfgs = dict(
attn_kwargs=dict(rd_ratio=1/8), attn_kwargs=dict(rd_ratio=1/8),
num_features=640, num_features=640,
), ),
mobilevitv2_050=_mobilevitv2_cfg(.50),
mobilevitv2_075=_mobilevitv2_cfg(.75),
mobilevitv2_125=_mobilevitv2_cfg(1.25),
mobilevitv2_100=_mobilevitv2_cfg(1.0),
mobilevitv2_150=_mobilevitv2_cfg(1.5),
mobilevitv2_175=_mobilevitv2_cfg(1.75),
mobilevitv2_200=_mobilevitv2_cfg(2.0),
) )
@register_notrace_module @register_notrace_module
class MobileViTBlock(nn.Module): class MobileVitBlock(nn.Module):
""" MobileViT block """ MobileViT block
Paper: https://arxiv.org/abs/2110.02178?context=cs.LG Paper: https://arxiv.org/abs/2110.02178?context=cs.LG
""" """
@ -165,9 +250,9 @@ class MobileViTBlock(nn.Module):
drop_path_rate: float = 0., drop_path_rate: float = 0.,
layers: LayerFn = None, layers: LayerFn = None,
transformer_norm_layer: Callable = nn.LayerNorm, transformer_norm_layer: Callable = nn.LayerNorm,
downsample: str = '' **kwargs, # eat unused args
): ):
super(MobileViTBlock, self).__init__() super(MobileVitBlock, self).__init__()
layers = layers or LayerFn() layers = layers or LayerFn()
groups = num_groups(group_size, in_chs) groups = num_groups(group_size, in_chs)
@ -241,7 +326,270 @@ class MobileViTBlock(nn.Module):
return x return x
register_block('mobilevit', MobileViTBlock) class LinearSelfAttention(nn.Module):
"""
This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680`
This layer can be used for self- as well as cross-attention.
Args:
embed_dim (int): :math:`C` from an expected input of size :math:`(N, C, H, W)`
attn_drop (float): Dropout value for context scores. Default: 0.0
bias (bool): Use bias in learnable layers. Default: True
Shape:
- Input: :math:`(N, C, P, N)` where :math:`N` is the batch size, :math:`C` is the input channels,
:math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches
- Output: same as the input
.. note::
For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels
in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor,
we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be
expensive on resource-constrained devices) that may be required to convert the unfolded tensor from
channel-first to channel-last format in case of a linear layer.
"""
def __init__(
self,
embed_dim: int,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
self.embed_dim = embed_dim
self.qkv_proj = nn.Conv2d(
in_channels=embed_dim,
out_channels=1 + (2 * embed_dim),
bias=bias,
kernel_size=1,
)
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = nn.Conv2d(
in_channels=embed_dim,
out_channels=embed_dim,
bias=bias,
kernel_size=1,
)
self.out_drop = nn.Dropout(proj_drop)
def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor:
# [B, C, P, N] --> [B, h + 2d, P, N]
qkv = self.qkv_proj(x)
# Project x into query, key and value
# Query --> [B, 1, P, N]
# value, key --> [B, d, P, N]
query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
# apply softmax along N dimension
context_scores = F.softmax(query, dim=-1)
context_scores = self.attn_drop(context_scores)
# Compute context vector
# [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
# combine context vector with values
# [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
out = F.relu(value) * context_vector.expand_as(value)
out = self.out_proj(out)
out = self.out_drop(out)
return out
@torch.jit.ignore()
def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
# x --> [B, C, P, N]
# x_prev = [B, C, P, M]
batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape
q_patch_area, q_num_patches = x.shape[-2:]
assert (
kv_patch_area == q_patch_area
), "The number of pixels in a patch for query and key_value should be the same"
# compute query, key, and value
# [B, C, P, M] --> [B, 1 + d, P, M]
qk = F.conv2d(
x_prev,
weight=self.qkv_proj.weight[:self.embed_dim + 1],
bias=self.qkv_proj.bias[:self.embed_dim + 1],
)
# [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
query, key = qk.split([1, self.embed_dim], dim=1)
# [B, C, P, N] --> [B, d, P, N]
value = F.conv2d(
x,
weight=self.qkv_proj.weight[self.embed_dim + 1],
bias=self.qkv_proj.bias[self.embed_dim + 1] if self.qkv_proj.bias is not None else None,
)
# apply softmax along M dimension
context_scores = F.softmax(query, dim=-1)
context_scores = self.attn_drop(context_scores)
# compute context vector
# [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1]
context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
# combine context vector with values
# [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
out = F.relu(value) * context_vector.expand_as(value)
out = self.out_proj(out)
out = self.out_drop(out)
return out
def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
if x_prev is None:
return self._forward_self_attn(x)
else:
return self._forward_cross_attn(x, x_prev=x_prev)
class LinearTransformerBlock(nn.Module):
"""
This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 paper <>`_
Args:
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)`
mlp_ratio (float): Inner dimension ratio of the FFN relative to embed_dim
drop (float): Dropout rate. Default: 0.0
attn_drop (float): Dropout rate for attention in multi-head attention. Default: 0.0
drop_path (float): Stochastic depth rate Default: 0.0
norm_layer (Callable): Normalization layer. Default: layer_norm_2d
Shape:
- Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim,
:math:`P` is number of pixels in a patch, and :math:`N` is number of patches,
- Output: same shape as the input
"""
def __init__(
self,
embed_dim: int,
mlp_ratio: float = 2.0,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
act_layer=None,
norm_layer=None,
) -> None:
super().__init__()
act_layer = act_layer or nn.SiLU
norm_layer = norm_layer or GroupNorm1
self.norm1 = norm_layer(embed_dim)
self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop)
self.drop_path1 = DropPath(drop_path)
self.norm2 = norm_layer(embed_dim)
self.mlp = ConvMlp(
in_features=embed_dim,
hidden_features=int(embed_dim * mlp_ratio),
act_layer=act_layer,
drop=drop)
self.drop_path2 = DropPath(drop_path)
def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
if x_prev is None:
# self-attention
x = x + self.drop_path1(self.attn(self.norm1(x)))
else:
# cross-attention
res = x
x = self.norm1(x) # norm
x = self.attn(x, x_prev) # attn
x = self.drop_path1(x) + res # residual
# Feed forward network
x = x + self.drop_path2(self.mlp(self.norm2(x)))
return x
@register_notrace_module
class MobileVitV2Block(nn.Module):
"""
This class defines the `MobileViTv2 block <>`_
"""
def __init__(
self,
in_chs: int,
out_chs: Optional[int] = None,
kernel_size: int = 3,
bottle_ratio: float = 1.0,
group_size: Optional[int] = 1,
dilation: Tuple[int, int] = (1, 1),
mlp_ratio: float = 2.0,
transformer_dim: Optional[int] = None,
transformer_depth: int = 2,
patch_size: int = 8,
attn_drop: float = 0.,
drop: int = 0.,
drop_path_rate: float = 0.,
layers: LayerFn = None,
transformer_norm_layer: Callable = GroupNorm1,
**kwargs, # eat unused args
):
super(MobileVitV2Block, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
out_chs = out_chs or in_chs
transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)
self.conv_kxk = layers.conv_norm_act(
in_chs, in_chs, kernel_size=kernel_size,
stride=1, groups=groups, dilation=dilation[0])
self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)
self.transformer = nn.Sequential(*[
LinearTransformerBlock(
transformer_dim,
mlp_ratio=mlp_ratio,
attn_drop=attn_drop,
drop=drop,
drop_path=drop_path_rate,
act_layer=layers.act,
norm_layer=transformer_norm_layer
)
for _ in range(transformer_depth)
])
self.norm = transformer_norm_layer(transformer_dim)
self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False)
self.patch_size = to_2tuple(patch_size)
self.patch_area = self.patch_size[0] * self.patch_size[1]
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
patch_h, patch_w = self.patch_size
new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w
num_patches = num_patch_h * num_patch_w # N
if new_h != H or new_w != W:
x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True)
# Local representation
x = self.conv_kxk(x)
x = self.conv_1x1(x)
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
C = x.shape[1]
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
x = x.reshape(B, C, -1, num_patches)
# Global representations
x = self.transformer(x)
x = self.norm(x)
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
x = self.conv_proj(x)
return x
register_block('mobilevit', MobileVitBlock)
register_block('mobilevit2', MobileVitV2Block)
def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
@ -252,6 +600,14 @@ def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
**kwargs) **kwargs)
def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
ByobNet, variant, pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model @register_model
def mobilevit_xxs(pretrained=False, **kwargs): def mobilevit_xxs(pretrained=False, **kwargs):
return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs) return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)
@ -270,3 +626,74 @@ def mobilevit_s(pretrained=False, **kwargs):
@register_model @register_model
def semobilevit_s(pretrained=False, **kwargs): def semobilevit_s(pretrained=False, **kwargs):
return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_050(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_075(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_100(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_125(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_150(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_175(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_200(pretrained=False, **kwargs):
return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)
@register_model
def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs):
return _create_mobilevit(
'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)

@ -26,7 +26,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, checkpoint_seq from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1
from .registry import register_model from .registry import register_model
@ -80,15 +80,6 @@ class PatchEmbed(nn.Module):
return x return x
class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class Pooling(nn.Module): class Pooling(nn.Module):
def __init__(self, pool_size=3): def __init__(self, pool_size=3):
super().__init__() super().__init__()

@ -35,6 +35,16 @@ def _cfg(url='', **kwargs):
default_cfgs = { default_cfgs = {
# ResNet and Wide ResNet # ResNet and Wide ResNet
'resnet10t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet10t_176_c3-f3215ab1.pth',
input_size=(3, 176, 176), pool_size=(6, 6),
test_crop_pct=0.95, test_input_size=(3, 224, 224),
first_conv='conv1.0'),
'resnet14t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet14t_176_c3-c4ed2c37.pth',
input_size=(3, 176, 176), pool_size=(6, 6),
test_crop_pct=0.95, test_input_size=(3, 224, 224),
first_conv='conv1.0'),
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
'resnet18d': _cfg( 'resnet18d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth',
@ -262,6 +272,9 @@ default_cfgs = {
'resnetblur101d': _cfg( 'resnetblur101d': _cfg(
url='', url='',
interpolation='bicubic', first_conv='conv1.0'), interpolation='bicubic', first_conv='conv1.0'),
'resnetaa50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth',
test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic'),
'resnetaa50d': _cfg( 'resnetaa50d': _cfg(
url='', url='',
interpolation='bicubic', first_conv='conv1.0'), interpolation='bicubic', first_conv='conv1.0'),
@ -723,6 +736,24 @@ def _create_resnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
@register_model
def resnet10t(pretrained=False, **kwargs):
"""Constructs a ResNet-10-T model.
"""
model_args = dict(
block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
return _create_resnet('resnet10t', pretrained, **model_args)
@register_model
def resnet14t(pretrained=False, **kwargs):
"""Constructs a ResNet-14-T model.
"""
model_args = dict(
block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
return _create_resnet('resnet14t', pretrained, **model_args)
@register_model @register_model
def resnet18(pretrained=False, **kwargs): def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model. """Constructs a ResNet-18 model.
@ -1436,6 +1467,14 @@ def resnetblur101d(pretrained=False, **kwargs):
return _create_resnet('resnetblur101d', pretrained, **model_args) return _create_resnet('resnetblur101d', pretrained, **model_args)
@register_model
def resnetaa50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with avgpool anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, **kwargs)
return _create_resnet('resnetaa50', pretrained, **model_args)
@register_model @register_model
def resnetaa50d(pretrained=False, **kwargs): def resnetaa50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with avgpool anti-aliasing """Constructs a ResNet-50-D model with avgpool anti-aliasing

@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
def __init__( def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
""" """
Args: Args:
img_size (int, tuple): input image size img_size (int, tuple): input image size
@ -360,15 +360,17 @@ class VisionTransformer(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = global_pool self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1 if class_token else 0 self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
self.grad_checkpointing = False self.grad_checkpointing = False
self.patch_embed = embed_layer( self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate) self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
@ -428,11 +430,24 @@ class VisionTransformer(nn.Module):
self.global_pool = global_pool self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x): def _pos_embed(self, x):
x = self.patch_embed(x) if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + self.pos_embed
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if self.cls_token is not None: if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed) x = x + self.pos_embed
return self.pos_drop(x)
def forward_features(self, x):
x = self.patch_embed(x)
x = self._pos_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x) x = checkpoint_seq(self.blocks, x)
else: else:
@ -442,7 +457,7 @@ class VisionTransformer(nn.Module):
def forward_head(self, x, pre_logits: bool = False): def forward_head(self, x, pre_logits: bool = False):
if self.global_pool: if self.global_pool:
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x) x = self.fc_norm(x)
return x if pre_logits else self.head(x) return x if pre_logits else self.head(x)
@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape: if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) pos_embed_w,
model.pos_embed,
getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
model.pos_embed.copy_(pos_embed_w) model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from # Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
ntok_new = posemb_new.shape[1] ntok_new = posemb_new.shape[1]
if num_tokens: if num_prefix_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
ntok_new -= num_tokens ntok_new -= num_prefix_tokens
else: else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0] posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid))) gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2 gs_new = [int(math.sqrt(ntok_new))] * 2
@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1) posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
return posemb return posemb
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv""" """ convert patch embedding weight from manual patchify + linear proj to conv"""
import re
out_dict = {} out_dict = {}
if 'model' in state_dict: if 'model' in state_dict:
# For deit models # For deit models
state_dict = state_dict['model'] state_dict = state_dict['model']
for k, v in state_dict.items(): for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4: if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification # For old models that I trained prior to conv based patchification
O, I, H, W = model.patch_embed.proj.weight.shape O, I, H, W = model.patch_embed.proj.weight.shape
v = v.reshape(O, -1, H, W) v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape: elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights # To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed( v = resize_pos_embed(
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) v,
model.pos_embed,
getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
elif 'gamma_' in k:
# remap layer-scale gamma into sub-module (deit3 models)
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
elif 'pre_logits' in k: elif 'pre_logits' in k:
# NOTE representation layer removed as not used in latest 21k/1k pretrained weights # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
continue continue

@ -8,6 +8,7 @@ import math
import logging import logging
from functools import partial from functools import partial
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
@ -16,7 +17,7 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
from .registry import register_model from .registry import register_model
@ -47,9 +48,16 @@ default_cfgs = {
'vit_relpos_base_patch16_224': _cfg( 'vit_relpos_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
'vit_srelpos_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'),
'vit_srelpos_medium_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'),
'vit_relpos_medium_patch16_cls_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'),
'vit_relpos_base_patch16_cls_224': _cfg( 'vit_relpos_base_patch16_cls_224': _cfg(
url=''), url=''),
'vit_relpos_base_patch16_gapcls_224': _cfg( 'vit_relpos_base_patch16_clsgap_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
'vit_relpos_small_patch16_rpn_224': _cfg(url=''), 'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
@ -59,35 +67,43 @@ default_cfgs = {
} }
def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor: def gen_relative_position_index(
# cut and paste w/ modifications from swin / beit codebase q_size: Tuple[int, int],
# cls to token & token 2 cls & cls to cls k_size: Tuple[int, int] = None,
class_token: bool = False) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window # get pair-wise relative position index for each token inside the window
window_area = win_size[0] * win_size[1] q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww if k_size is None:
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww k_coords = q_coords
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 k_size = q_size
relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0 else:
relative_coords[:, :, 1] += win_size[1] - 1 # different q vs k sizes is a WIP
relative_coords[:, :, 0] *= 2 * win_size[1] - 1 k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
if class_token: if class_token:
num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3 # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) # NOTE not intended or tested with MLP log-coords
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2 relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1 relative_position_index[0, 0] = num_relative_distance - 1
else:
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww return relative_position_index.contiguous()
return relative_position_index
def gen_relative_log_coords( def gen_relative_log_coords(
win_size: Tuple[int, int], win_size: Tuple[int, int],
pretrained_win_size: Tuple[int, int] = (0, 0), pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin' mode='swin',
): ):
# as per official swin-v2 impl, supporting timm swin-v2-cr coords as well assert mode in ('swin', 'cr', 'rw')
# as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@ -100,12 +116,22 @@ def gen_relative_log_coords(
relative_coords_table[:, :, 0] /= (win_size[0] - 1) relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1) relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table *= 8 # normalize to -8, 8
scale = math.log2(8) relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / math.log2(8)
else: else:
# FIXME we should support a form of normalization (to -1/1) for this mode? if mode == 'rw':
scale = math.log2(math.e) # cr w/ window size normalization -> [-1,1] log coords
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # scale to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2( relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / scale 1.0 + relative_coords_table.abs())
relative_coords_table /= math.log2(9) # -> [-1, 1]
else:
# mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
1.0 + relative_coords_table.abs())
return relative_coords_table return relative_coords_table
@ -115,19 +141,29 @@ class RelPosMlp(nn.Module):
window_size, window_size,
num_heads=8, num_heads=8,
hidden_dim=128, hidden_dim=128,
class_token=False, prefix_tokens=0,
mode='cr', mode='cr',
pretrained_window_size=(0, 0) pretrained_window_size=(0, 0)
): ):
super().__init__() super().__init__()
self.window_size = window_size self.window_size = window_size
self.window_area = self.window_size[0] * self.window_size[1] self.window_area = self.window_size[0] * self.window_size[1]
self.class_token = 1 if class_token else 0 self.prefix_tokens = prefix_tokens
self.num_heads = num_heads self.num_heads = num_heads
self.bias_shape = (self.window_area,) * 2 + (num_heads,) self.bias_shape = (self.window_area,) * 2 + (num_heads,)
self.apply_sigmoid = mode == 'swin' if mode == 'swin':
self.bias_act = nn.Sigmoid()
self.bias_gain = 16
mlp_bias = (True, False)
elif mode == 'rw':
self.bias_act = nn.Tanh()
self.bias_gain = 4
mlp_bias = True
else:
self.bias_act = nn.Identity()
self.bias_gain = None
mlp_bias = True
mlp_bias = (True, False) if mode == 'swin' else True
self.mlp = Mlp( self.mlp = Mlp(
2, # x, y 2, # x, y
hidden_features=hidden_dim, hidden_features=hidden_dim,
@ -155,10 +191,11 @@ class RelPosMlp(nn.Module):
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(self.bias_shape) relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1) relative_position_bias = relative_position_bias.permute(2, 0, 1)
if self.apply_sigmoid: relative_position_bias = self.bias_act(relative_position_bias)
relative_position_bias = 16 * torch.sigmoid(relative_position_bias) if self.bias_gain is not None:
if self.class_token: relative_position_bias = self.bias_gain * relative_position_bias
relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0]) if self.prefix_tokens:
relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
return relative_position_bias.unsqueeze(0).contiguous() return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
@ -167,18 +204,18 @@ class RelPosMlp(nn.Module):
class RelPosBias(nn.Module): class RelPosBias(nn.Module):
def __init__(self, window_size, num_heads, class_token=False): def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__() super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size self.window_size = window_size
self.window_area = window_size[0] * window_size[1] self.window_area = window_size[0] * window_size[1]
self.class_token = 1 if class_token else 0 self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer( self.register_buffer(
"relative_position_index", "relative_position_index",
gen_relative_position_index(self.window_size, class_token=self.class_token), gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
persistent=False, persistent=False,
) )
@ -306,11 +343,32 @@ class VisionTransformerRelPos(nn.Module):
""" """
def __init__( def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', self,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6, img_size=224,
class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None, patch_size=16,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', in_chans=3,
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock): num_classes=1000,
global_pool='avg',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=1e-6,
class_token=False,
fc_norm=False,
rel_pos_type='mlp',
rel_pos_dim=None,
shared_rel_pos=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='skip',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=RelPosBlock
):
""" """
Args: Args:
img_size (int, tuple): input image size img_size (int, tuple): input image size
@ -345,19 +403,22 @@ class VisionTransformerRelPos(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = global_pool self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1 if class_token else 0 self.num_prefix_tokens = 1 if class_token else 0
self.grad_checkpointing = False self.grad_checkpointing = False
self.patch_embed = embed_layer( self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
feat_size = self.patch_embed.grid_size feat_size = self.patch_embed.grid_size
rel_pos_args = dict(window_size=feat_size, class_token=class_token) rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
if rel_pos_type.startswith('mlp'): if rel_pos_type.startswith('mlp'):
if rel_pos_dim: if rel_pos_dim:
rel_pos_args['hidden_dim'] = rel_pos_dim rel_pos_args['hidden_dim'] = rel_pos_dim
# FIXME experimenting with different relpos log coord configs
if 'swin' in rel_pos_type: if 'swin' in rel_pos_type:
rel_pos_args['mode'] = 'swin' rel_pos_args['mode'] = 'swin'
elif 'rw' in rel_pos_type:
rel_pos_args['mode'] = 'rw'
rel_pos_cls = partial(RelPosMlp, **rel_pos_args) rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
else: else:
rel_pos_cls = partial(RelPosBias, **rel_pos_args) rel_pos_cls = partial(RelPosBias, **rel_pos_args)
@ -367,7 +428,7 @@ class VisionTransformerRelPos(nn.Module):
# NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both... # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
rel_pos_cls = None rel_pos_cls = None
self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
@ -434,7 +495,7 @@ class VisionTransformerRelPos(nn.Module):
def forward_head(self, x, pre_logits: bool = False): def forward_head(self, x, pre_logits: bool = False):
if self.global_pool: if self.global_pool:
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x) x = self.fc_norm(x)
return x if pre_logits else self.head(x) return x if pre_logits else self.head(x)
@ -502,6 +563,41 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_srelpos_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ shared relative log-coord position, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=False,
rel_pos_dim=384, shared_rel_pos=True, **kwargs)
model = _create_vision_transformer_relpos('vit_srelpos_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_srelpos_medium_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ shared relative log-coord position, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
rel_pos_dim=512, shared_rel_pos=True, **kwargs)
model = _create_vision_transformer_relpos(
'vit_srelpos_medium_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_medium_patch16_cls_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-M/16) w/ relative log-coord position, class token present
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
rel_pos_dim=256, class_token=True, global_pool='token', **kwargs)
model = _create_vision_transformer_relpos(
'vit_relpos_medium_patch16_cls_224', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs): def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, class token present """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
@ -514,14 +610,14 @@ def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
@register_model @register_model
def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs): def vit_relpos_base_patch16_clsgap_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, class token present """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
Leaving here for comparisons w/ a future re-train as it performs quite well. Leaving here for comparisons w/ a future re-train as it performs quite well.
""" """
model_kwargs = dict( model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs) patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer_relpos('vit_relpos_base_patch16_clsgap_224', pretrained=pretrained, **model_kwargs)
return model return model

@ -1 +1 @@
__version__ = '0.6.2.dev0' __version__ = '0.6.3.dev0'

@ -38,6 +38,12 @@ try:
except AttributeError: except AttributeError:
pass pass
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate') _logger = logging.getLogger('validate')
@ -61,6 +67,8 @@ parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty') metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int, parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float, parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop pct') metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
@ -101,8 +109,11 @@ parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed') help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true', parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present') help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true', scripting_group = parser.add_mutually_exclusive_group()
help='convert model torchscript for inference') scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
help='torch.jit.script the full model')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
parser.add_argument('--fuser', default='', type=str, parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
@ -155,14 +166,22 @@ def validate(args):
param_count = sum([m.numel() for m in model.parameters()]) param_count = sum([m.numel() for m in model.parameters()])
_logger.info('Model %s created, param count: %d' % (args.model, param_count)) _logger.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True) data_config = resolve_data_config(
vars(args),
model=model,
use_test_size=not args.use_train_size,
verbose=True
)
test_time_pool = False test_time_pool = False
if args.test_pool: if args.test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) model, test_time_pool = apply_test_time_pool(model, data_config)
if args.torchscript: if args.torchscript:
torch.jit.optimized_execution(True) torch.jit.optimized_execution(True)
model = torch.jit.script(model) model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size']))
if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
model = model.cuda() model = model.cuda()
if args.apex_amp: if args.apex_amp:

Loading…
Cancel
Save