diff --git a/README.md b/README.md index 4c39c692..e4c058f1 100644 --- a/README.md +++ b/README.md @@ -13,16 +13,50 @@ ## Sponsors -A big thank you to my [GitHub Sponsors](https://github.com/sponsors/rwightman) for their support! - -In addition to the sponsors at the link above, I've received hardware and/or cloud resources from +Thanks to the following for hardware support: +* TPU Research Cloud (TRC) (https://sites.research.google/trc/about/) * Nvidia (https://www.nvidia.com/en-us/) -* TFRC (https://www.tensorflow.org/tfrc) -I'm fortunate to be able to dedicate significant time and money of my own supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of cloud services, hardware, and electricity costs. +And a big thanks to all GitHub sponsors who helped with some of my costs before I joined Hugging Face. ## What's New +### July 8, 2022 +More models, more fixes +* Official research models (w/ weights) added: + * EdgeNeXt from (https://github.com/mmaaz60/EdgeNeXt) + * MobileViT-V2 from (https://github.com/apple/ml-cvnets) + * DeiT III (Revenge of the ViT) from (https://github.com/facebookresearch/deit) +* My own models: + * Small `ResNet` defs added by request with 1 block repeats for both basic and bottleneck (resnet10 and resnet14) + * `CspNet` refactored with dataclass config, simplified CrossStage3 (`cs3`) option. These are closer to YOLO-v5+ backbone defs. + * More relative position vit fiddling. Two `srelpos` (shared relative position) models trained, and a medium w/ class token. + * Add an alternate downsample mode to EdgeNeXt and train a `small` model. Better than original small, but not their new USI trained weights. +* My own model weight results (all ImageNet-1k training) + * `resnet10t` - 66.5 @ 176, 68.3 @ 224 + * `resnet14t` - 71.3 @ 176, 72.3 @ 224 + * `resnetaa50` - 80.6 @ 224 , 81.6 @ 288 + * `darknet53` - 80.0 @ 256, 80.5 @ 288 + * `cs3darknet_m` - 77.0 @ 256, 77.6 @ 288 + * `cs3darknet_focus_m` - 76.7 @ 256, 77.3 @ 288 + * `cs3darknet_l` - 80.4 @ 256, 80.9 @ 288 + * `cs3darknet_focus_l` - 80.3 @ 256, 80.9 @ 288 + * `vit_srelpos_small_patch16_224` - 81.1 @ 224, 82.1 @ 320 + * `vit_srelpos_medium_patch16_224` - 82.3 @ 224, 83.1 @ 320 + * `vit_relpos_small_patch16_cls_224` - 82.6 @ 224, 83.6 @ 320 + * `edgnext_small_rw` - 79.6 @ 224, 80.4 @ 320 +* `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs. +* Hugging Face Hub support fixes verified, demo notebook TBA +* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation. +* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103) +* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases. + * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges. + * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py` +* Numerous bug fixes +* Currently testing for imminent PyPi 0.6.x release +* LeViT pretraining of larger models still a WIP, they don't train well / easily without distillation. Time to add distill support (finally)? +* ImageNet-22k weight training + finetune ongoing, work on multi-weight support (slowly) chugging along (there are a LOT of weights, sigh) ... + ### May 13, 2022 * Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript. * Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects. @@ -349,6 +383,7 @@ A full version of the list below with source links can be found in the [document * DenseNet - https://arxiv.org/abs/1608.06993 * DLA - https://arxiv.org/abs/1707.06484 * DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629 +* EdgeNeXt - https://arxiv.org/abs/2206.10589 * EfficientNet (MBConvNet Family) * EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252 * EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665 diff --git a/benchmark.py b/benchmark.py index 422da45d..23047bb5 100755 --- a/benchmark.py +++ b/benchmark.py @@ -6,24 +6,23 @@ An inference and train step benchmark script for timm models. Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse -import os import csv import json -import time import logging -import torch -import torch.nn as nn -import torch.nn.parallel +import time from collections import OrderedDict from contextlib import suppress from functools import partial +import torch +import torch.nn as nn +import torch.nn.parallel + +from timm.data import resolve_data_config from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 -from timm.data import resolve_data_config from timm.utils import setup_default_logging, set_jit_fuser - has_apex = False try: from apex import amp @@ -51,6 +50,12 @@ except ImportError as e: FlopCountAnalysis = None has_fvcore_profiling = False +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -65,6 +70,8 @@ parser.add_argument('--bench', default='both', type=str, help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'") parser.add_argument('--detail', action='store_true', default=False, help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False') +parser.add_argument('--no-retry', action='store_true', default=False, + help='Do not decay batch size and retry on error.') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for validation results (summary)') parser.add_argument('--num-warm-iter', default=10, type=int, @@ -95,10 +102,13 @@ parser.add_argument('--amp', action='store_true', default=False, help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.') parser.add_argument('--precision', default='float32', type=str, help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)') -parser.add_argument('--torchscript', dest='torchscript', action='store_true', - help='convert model torchscript for inference') parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +scripting_group = parser.add_mutually_exclusive_group() +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") # train optimizer parameters @@ -160,10 +170,9 @@ def resolve_precision(precision: str): def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False): - macs, _ = get_model_profile( + _, macs, _ = get_model_profile( model=model, - input_res=(batch_size,) + input_size, # input shape or input to the input_constructor - input_constructor=None, # if specified, a constructor taking input_res is used as input to the model + input_shape=(batch_size,) + input_size, # input shape/resolution print_profile=detailed, # prints the model graph with the measured profile attached to each module detailed=detailed, # print the detailed profile warm_up=10, # the number of warm-ups before measuring the time of each module @@ -188,8 +197,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False class BenchmarkRunner: def __init__( - self, model_name, detail=False, device='cuda', torchscript=False, precision='float32', - fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): + self, + model_name, + detail=False, + device='cuda', + torchscript=False, + aot_autograd=False, + precision='float32', + fuser='', + num_warm_iter=10, + num_bench_iter=50, + use_train_size=False, + **kwargs + ): self.model_name = model_name self.detail = detail self.device = device @@ -216,15 +236,19 @@ class BenchmarkRunner: self.num_classes = self.model.num_classes self.param_count = count_params(self.model) _logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) + + data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.scripted = False if torchscript: self.model = torch.jit.script(self.model) self.scripted = True - - data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] self.batch_size = kwargs.pop('batch_size', 256) + if aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + self.model = memory_efficient_fusion(self.model) + self.example_inputs = None self.num_warm_iter = num_warm_iter self.num_bench_iter = num_bench_iter @@ -243,7 +267,13 @@ class BenchmarkRunner: class InferenceBenchmarkRunner(BenchmarkRunner): - def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): + def __init__( + self, + model_name, + device='cuda', + torchscript=False, + **kwargs + ): super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) self.model.eval() @@ -312,7 +342,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner): class TrainBenchmarkRunner(BenchmarkRunner): - def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): + def __init__( + self, + model_name, + device='cuda', + torchscript=False, + **kwargs + ): super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) self.model.train() @@ -479,7 +515,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16): return max(0, int(out_batch_size)) -def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): +def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False): batch_size = initial_batch_size results = dict() error_str = 'Unknown' @@ -494,8 +530,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): if 'channels_last' in error_str: _logger.error(f'{model_name} not supported in channels_last, skipping.') break - _logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.') + _logger.error(f'"{error_str}" while running benchmark.') + if no_batch_size_retry: + break batch_size = decay_batch_exp(batch_size) + _logger.warning(f'Reducing batch size to {batch_size} for retry.') results['error'] = error_str return results @@ -537,7 +576,13 @@ def benchmark(args): model_results = OrderedDict(model=model) for prefix, bench_fn in zip(prefixes, bench_fns): - run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs) + run_results = _try_run( + model, + bench_fn, + bench_kwargs=bench_kwargs, + initial_batch_size=batch_size, + no_batch_size_retry=args.no_retry, + ) if prefix and 'error' not in run_results: run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()} model_results.update(run_results) diff --git a/timm/bits/train_setup.py b/timm/bits/train_setup.py index 87c1b5c5..37d0e325 100644 --- a/timm/bits/train_setup.py +++ b/timm/bits/train_setup.py @@ -5,6 +5,7 @@ import logging import torch import torch.nn as nn +from timm.models.layers import convert_sync_batchnorm from timm.optim import create_optimizer_v2 from timm.utils import ModelEmaV2 @@ -65,7 +66,7 @@ def setup_model_and_optimizer( dev_env.to_device(model) if use_syncbn and dev_env.distributed: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = convert_sync_batchnorm(model) if dev_env.primary: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' diff --git a/timm/data/__init__.py b/timm/data/__init__.py index de978419..2ff68bfb 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset_factory import create_dataset from .loader import create_loader_v2 from .mixup import Mixup, FastCollateMixup -from .parsers import create_parser +from .parsers import create_parser,\ + get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions from .real_labels import RealLabelsImagenet from .transforms import RandomResizedCropAndInterpolation, ToTensor, ToNumpy from .transforms_factory import create_transform_v2, create_transform diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index f6649db1..78c879fb 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -36,27 +36,46 @@ _HPARAMS_DEFAULT = dict( img_mean=_FILL, ) -_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in +# favor of the Image.Resampling enum. The top-level resampling attributes will be +# removed in Pillow 10. +if hasattr(Image, "Resampling"): + _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC) + _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC + + _pil_interpolation_to_str = { + Image.Resampling.NEAREST: 'nearest', + Image.Resampling.BILINEAR: 'bilinear', + Image.Resampling.BICUBIC: 'bicubic', + Image.Resampling.BOX: 'box', + Image.Resampling.HAMMING: 'hamming', + Image.Resampling.LANCZOS: 'lanczos', + } +else: + _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + _DEFAULT_INTERPOLATION = Image.BICUBIC + + _pil_interpolation_to_str = { + Image.NEAREST: 'nearest', + Image.BILINEAR: 'bilinear', + Image.BICUBIC: 'bicubic', + Image.BOX: 'box', + Image.HAMMING: 'hamming', + Image.LANCZOS: 'lanczos', + } + +_str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()} def _pil_interp(method): - def _convert(m): - if method == 'bicubic': - return Image.BICUBIC - elif method == 'lanczos': - return Image.LANCZOS - elif method == 'hamming': - return Image.HAMMING - else: - return Image.BILINEAR if isinstance(method, (list, tuple)): - return [_convert(m) if isinstance(m, str) else m for m in method] + return [_str_to_pil_interpolation(m) if isinstance(m, str) else m for m in method] else: - return _convert(method) if isinstance(method, str) else method + return _str_to_pil_interpolation(method) if isinstance(method, str) else method def _interpolation(kwargs): - interpolation = kwargs.pop('resample', Image.BILINEAR) + interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION) if isinstance(interpolation, (list, tuple)): return random.choice(interpolation) else: diff --git a/timm/data/config.py b/timm/data/config.py index f9ed7b6c..2e9b8373 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -107,11 +107,15 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v new_config['std'] = default_cfg['std'] # resolve default crop percentage - new_config['crop_pct'] = DEFAULT_CROP_PCT + crop_pct = DEFAULT_CROP_PCT if 'crop_pct' in args and args['crop_pct'] is not None: - new_config['crop_pct'] = args['crop_pct'] - elif 'crop_pct' in default_cfg: - new_config['crop_pct'] = default_cfg['crop_pct'] + crop_pct = args['crop_pct'] + else: + if use_test_size and 'test_crop_pct' in default_cfg: + crop_pct = default_cfg['test_crop_pct'] + elif 'crop_pct' in default_cfg: + crop_pct = default_cfg['crop_pct'] + new_config['crop_pct'] = crop_pct if getattr(args, 'mixup', 0) > 0 \ or getattr(args, 'cutmix', 0) > 0. \ diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index c4e6cf3c..647357a9 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -26,8 +26,8 @@ _TORCH_BASIC_DS = dict( kmnist=KMNIST, fashion_mnist=FashionMNIST, ) -_TRAIN_SYNONYM = {'train', 'training'} -_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'} +_TRAIN_SYNONYM = dict(train=None, training=None) +_EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None) def _search_split(root, split): diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py index eeb44e37..4e820d5e 100644 --- a/timm/data/parsers/__init__.py +++ b/timm/data/parsers/__init__.py @@ -1 +1,2 @@ from .parser_factory import create_parser +from .img_extensions import * diff --git a/timm/data/parsers/constants.py b/timm/data/parsers/constants.py deleted file mode 100644 index e7ba484e..00000000 --- a/timm/data/parsers/constants.py +++ /dev/null @@ -1 +0,0 @@ -IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') diff --git a/timm/data/parsers/img_extensions.py b/timm/data/parsers/img_extensions.py new file mode 100644 index 00000000..45c85aab --- /dev/null +++ b/timm/data/parsers/img_extensions.py @@ -0,0 +1,50 @@ +from copy import deepcopy + +__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] + + +IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use +_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync + + +def _set_extensions(extensions): + global IMG_EXTENSIONS + global _IMG_EXTENSIONS_SET + dedupe = set() # NOTE de-duping tuple while keeping original order + IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) + _IMG_EXTENSIONS_SET = set(extensions) + + +def _valid_extension(x: str): + return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') + + +def is_img_extension(ext): + return ext in _IMG_EXTENSIONS_SET + + +def get_img_extensions(as_set=False): + return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) + + +def set_img_extensions(extensions): + assert len(extensions) + for x in extensions: + assert _valid_extension(x) + _set_extensions(extensions) + + +def add_img_extensions(ext): + if not isinstance(ext, (list, tuple, set)): + ext = (ext,) + for x in ext: + assert _valid_extension(x) + extensions = IMG_EXTENSIONS + tuple(ext) + _set_extensions(extensions) + + +def del_img_extensions(ext): + if not isinstance(ext, (list, tuple, set)): + ext = (ext,) + extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) + _set_extensions(extensions) diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py index 89383d9b..acb8204a 100644 --- a/timm/data/parsers/parser_factory.py +++ b/timm/data/parsers/parser_factory.py @@ -1,7 +1,6 @@ import os from .parser_image_folder import ParserImageFolder -from .parser_image_tar import ParserImageTar from .parser_image_in_tar import ParserImageInTar diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/parsers/parser_image_folder.py index ed349009..3d22a17b 100644 --- a/timm/data/parsers/parser_image_folder.py +++ b/timm/data/parsers/parser_image_folder.py @@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default. Hacked together by / Copyright 2020 Ross Wightman """ import os +from typing import Dict, List, Optional, Set, Tuple, Union from timm.utils.misc import natural_key -from .parser import Parser from .class_map import load_class_map -from .constants import IMG_EXTENSIONS +from .img_extensions import get_img_extensions +from .parser import Parser + + +def find_images_and_targets( + folder: str, + types: Optional[Union[List, Tuple, Set]] = None, + class_to_idx: Optional[Dict] = None, + leaf_name_only: bool = True, + sort: bool = True +): + """ Walk folder recursively to discover images and map them to classes by folder names. + Args: + folder: root of folder to recrusively search + types: types (file extensions) to search for in path + class_to_idx: specify mapping for class (folder name) to class index if set + leaf_name_only: use only leaf-name of folder walk for class names + sort: re-sort found images by name (for consistent ordering) -def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): + Returns: + A list of image and target tuples, class_to_idx mapping + """ + types = get_img_extensions(as_set=True) if not types else set(types) labels = [] filenames = [] for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): @@ -51,7 +71,8 @@ class ParserImageFolder(Parser): self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) if len(self.samples) == 0: raise RuntimeError( - f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') + f'Found 0 images in subfolders of {root}. ' + f'Supported image extensions are {", ".join(get_img_extensions())}') def __getitem__(self, index): path, target = self.samples[index] diff --git a/timm/data/parsers/parser_image_in_tar.py b/timm/data/parsers/parser_image_in_tar.py index c6ada962..4fcad797 100644 --- a/timm/data/parsers/parser_image_in_tar.py +++ b/timm/data/parsers/parser_image_in_tar.py @@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure. Hacked together by / Copyright 2020 Ross Wightman """ +import logging import os -import tarfile import pickle -import logging -import numpy as np +import tarfile from glob import glob -from typing import List, Dict +from typing import List, Tuple, Dict, Set, Optional, Union + +import numpy as np from timm.utils.misc import natural_key -from .parser import Parser from .class_map import load_class_map -from .constants import IMG_EXTENSIONS - +from .img_extensions import get_img_extensions +from .parser import Parser _logger = logging.getLogger(__name__) CACHE_FILENAME_SUFFIX = '_tarinfos.pickle' @@ -39,7 +39,7 @@ class TarState: self.tf = None -def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS): +def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]): sample_count = 0 for i, ti in enumerate(tf): if not ti.isfile(): @@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE return sample_count -def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True): +def extract_tarinfos( + root, + class_name_to_idx: Optional[Dict] = None, + cache_tarinfo: Optional[bool] = None, + extensions: Optional[Union[List, Tuple, Set]] = None, + sort: bool = True +): + extensions = get_img_extensions(as_set=True) if not extensions else set(extensions) root_is_tar = False if os.path.isfile(root): assert os.path.splitext(root)[-1].lower() == '.tar' @@ -176,8 +183,8 @@ class ParserImageInTar(Parser): self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos( self.root, class_name_to_idx=class_name_to_idx, - cache_tarinfo=cache_tarinfo, - extensions=IMG_EXTENSIONS) + cache_tarinfo=cache_tarinfo + ) self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()} if len(tarfiles) == 1 and tarfiles[0][0] is None: self.root_is_tar = True diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/parsers/parser_image_tar.py index 467537f4..c2ed429d 100644 --- a/timm/data/parsers/parser_image_tar.py +++ b/timm/data/parsers/parser_image_tar.py @@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman import os import tarfile -from .parser import Parser -from .class_map import load_class_map -from .constants import IMG_EXTENSIONS from timm.utils.misc import natural_key +from .class_map import load_class_map +from .img_extensions import get_img_extensions +from .parser import Parser + def extract_tarinfo(tarfile, class_to_idx=None, sort=True): + extensions = get_img_extensions(as_set=True) files = [] labels = [] for ti in tarfile.getmembers(): @@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True): dirname, basename = os.path.split(ti.path) label = os.path.basename(dirname) ext = os.path.splitext(basename)[1] - if ext.lower() in IMG_EXTENSIONS: + if ext.lower() in extensions: files.append(ti) labels.append(label) if class_to_idx is None: diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 03f0e825..6958829b 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -2,7 +2,6 @@ import torch import torchvision.transforms.functional as F from torchvision.transforms import InterpolationMode -from PIL import Image import warnings import math import random diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 8cb6c70a..195e451b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -12,6 +12,7 @@ from .deit import * from .densenet import * from .dla import * from .dpn import * +from .edgenext import * from .efficientnet import * from .ghostnet import * from .gluon_resnet import * @@ -61,7 +62,7 @@ from .xcit import * from .factory import create_model, parse_model_name, safe_model_name from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool -from .layers import convert_splitbn_model +from .layers import convert_splitbn_model, convert_sync_batchnorm from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 1aacef2b..be0c9a66 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -17,9 +17,8 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp +from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d from .registry import register_model @@ -44,6 +43,7 @@ default_cfgs = dict( convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_nano_hnf=_cfg(url=''), + convnext_nano_ols=_cfg(url=''), convnext_tiny_hnf=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', crop_pct=0.95), @@ -88,35 +88,6 @@ default_cfgs = dict( ) -def _is_contiguous(tensor: torch.Tensor) -> bool: - # jit is oh so lovely :/ - # if torch.jit.is_tracing(): - # return True - if torch.jit.is_scripting(): - return tensor.is_contiguous() - else: - return tensor.is_contiguous(memory_format=torch.contiguous_format) - - -@register_notrace_module -class LayerNorm2d(nn.LayerNorm): - r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). - """ - - def __init__(self, normalized_shape, eps=1e-6): - super().__init__(normalized_shape, eps=eps) - - def forward(self, x) -> torch.Tensor: - if _is_contiguous(x): - return F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) - else: - s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) - x = (x - u) * torch.rsqrt(s + self.eps) - x = x * self.weight[:, None, None] + self.bias[:, None, None] - return x - - class ConvNeXtBlock(nn.Module): """ ConvNeXt Block There are two equivalent implementations: @@ -133,16 +104,30 @@ class ConvNeXtBlock(nn.Module): ls_init_value (float): Init value for Layer Scale. Default: 1e-6. """ - def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None): + def __init__( + self, + dim, + dim_out=None, + stride=1, + mlp_ratio=4, + conv_mlp=False, + conv_bias=True, + ls_init_value=1e-6, + norm_layer=None, + act_layer=nn.GELU, + drop_path=0., + ): super().__init__() + dim_out = dim_out or dim if not norm_layer: norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp - self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv - self.norm = norm_layer(dim) - self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU) - self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + + self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias) + self.norm = norm_layer(dim_out) + self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): @@ -158,6 +143,7 @@ class ConvNeXtBlock(nn.Module): x = x.permute(0, 3, 1, 2) if self.gamma is not None: x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + x = self.drop_path(x) + shortcut return x @@ -165,25 +151,44 @@ class ConvNeXtBlock(nn.Module): class ConvNeXtStage(nn.Module): def __init__( - self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, - norm_layer=None, cl_norm_layer=None, cross_stage=False): + self, + in_chs, + out_chs, + stride=2, + depth=2, + drop_path_rates=None, + ls_init_value=1.0, + conv_mlp=False, + conv_bias=True, + norm_layer=None, + norm_layer_cl=None + ): super().__init__() self.grad_checkpointing = False if in_chs != out_chs or stride > 1: self.downsample = nn.Sequential( norm_layer(in_chs), - nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride), + nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias), ) + in_chs = out_chs else: self.downsample = nn.Identity() - dp_rates = dp_rates or [0.] * depth - self.blocks = nn.Sequential(*[ConvNeXtBlock( - dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, - norm_layer=norm_layer if conv_mlp else cl_norm_layer) - for j in range(depth)] - ) + drop_path_rates = drop_path_rates or [0.] * depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append(ConvNeXtBlock( + dim=in_chs, + dim_out=out_chs, + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + norm_layer=norm_layer if conv_mlp else norm_layer_cl + )) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) def forward(self, x): x = self.downsample(x) @@ -210,41 +215,56 @@ class ConvNeXt(nn.Module): """ def __init__( - self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, - depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch', - head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + output_stride=32, + depths=(3, 3, 9, 3), + dims=(96, 192, 384, 768), + ls_init_value=1e-6, + stem_type='patch', + stem_kernel_size=4, + stem_stride=4, + head_init_scale=1., + head_norm_first=False, + conv_mlp=False, + conv_bias=True, + norm_layer=None, + drop_rate=0., + drop_path_rate=0., ): super().__init__() assert output_stride == 32 if norm_layer is None: norm_layer = partial(LayerNorm2d, eps=1e-6) - cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' - cl_norm_layer = norm_layer + norm_layer_cl = norm_layer self.num_classes = num_classes self.drop_rate = drop_rate self.feature_info = [] - # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + assert stem_type in ('patch', 'overlap') if stem_type == 'patch': + assert stem_kernel_size == stem_stride + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), + nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias), norm_layer(dims[0]) ) - curr_stride = patch_size - prev_chs = dims[0] else: self.stem = nn.Sequential( - nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1), - norm_layer(32), - nn.GELU(), - nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.Conv2d( + in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, + padding=stem_kernel_size // 2, bias=conv_bias), + norm_layer(dims[0]), ) - curr_stride = 2 - prev_chs = 64 + prev_chs = dims[0] + curr_stride = stem_stride self.stages = nn.Sequential() dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] @@ -256,16 +276,23 @@ class ConvNeXt(nn.Module): curr_stride *= stride out_chs = dims[i] stages.append(ConvNeXtStage( - prev_chs, out_chs, stride=stride, - depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, - norm_layer=norm_layer, cl_norm_layer=cl_norm_layer) - ) + prev_chs, + out_chs, + stride=stride, + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl + )) prev_chs = out_chs # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.num_features = prev_chs + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() @@ -327,10 +354,11 @@ class ConvNeXt(nn.Module): def _init_weights(module, name=None, head_init_scale=1.0): if isinstance(module, nn.Conv2d): trunc_normal_(module.weight, std=.02) - nn.init.constant_(module.bias, 0) + if module.bias is not None: + nn.init.zeros_(module.bias) elif isinstance(module, nn.Linear): trunc_normal_(module.weight, std=.02) - nn.init.constant_(module.bias, 0) + nn.init.zeros_(module.bias) if name and 'head.' in name: module.weight.data.mul_(head_init_scale) module.bias.data.mul_(head_init_scale) @@ -371,14 +399,25 @@ def _create_convnext(variant, pretrained=False, **kwargs): @register_model def convnext_nano_hnf(pretrained=False, **kwargs): - model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) + model_args = dict( + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args) return model +@register_model +def convnext_nano_ols(pretrained=False, **kwargs): + model_args = dict( + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, + conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs) + model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args) + return model + + @register_model def convnext_tiny_hnf(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) + model_args = dict( + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) return model @@ -386,7 +425,7 @@ def convnext_tiny_hnf(pretrained=False, **kwargs): @register_model def convnext_tiny_hnfd(pretrained=False, **kwargs): model_args = dict( - depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, stem_type='dual', **kwargs) + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) return model diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index f8a87fab..e8e8910e 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,14 +12,18 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ +import collections.abc +from dataclasses import dataclass, field, asdict from functools import partial +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch import torch.nn as nn +import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP -from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer +from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, create_act_layer, make_divisible from .registry import register_model @@ -44,124 +48,315 @@ default_cfgs = { 'cspresnet50w': _cfg(url=''), 'cspresnext50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth', - input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.875 # FIXME I trained this at 224x224, not 256 like ref impl ), - 'cspresnext50_iabn': _cfg(url=''), 'cspdarknet53': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'), - 'cspdarknet53_iabn': _cfg(url=''), - 'darknet53': _cfg(url=''), + + 'darknet17': _cfg(url=''), + 'darknet21': _cfg(url=''), + 'sedarknet21': _cfg(url=''), + 'darknet53': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0, + ), + 'darknetaa53': _cfg(url=''), + + 'cs3darknet_s': _cfg( + url='', interpolation='bicubic'), + 'cs3darknet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95, + ), + 'cs3darknet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'cs3darknet_x': _cfg( + url=''), + + 'cs3darknet_focus_s': _cfg( + url='', interpolation='bicubic'), + 'cs3darknet_focus_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'cs3darknet_focus_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'cs3darknet_focus_x': _cfg( + url='', interpolation='bicubic'), + + 'cs3sedarknet_xdw': _cfg( + url='', interpolation='bicubic'), } +@dataclass +class CspStemCfg: + out_chs: Union[int, Tuple[int, ...]] = 32 + stride: Union[int, Tuple[int, ...]] = 2 + kernel_size: int = 3 + padding: Union[int, str] = '' + pool: Optional[str] = '' + + +def _pad_arg(x, n): + # pads an argument tuple to specified n by padding with last value + if not isinstance(x, (tuple, list)): + x = (x,) + curr_n = len(x) + pad_n = n - curr_n + if pad_n <= 0: + return x[:n] + return tuple(x + (x[-1],) * pad_n) + + +@dataclass +class CspStagesCfg: + depth: Tuple[int, ...] = (3, 3, 5, 2) # block depth (number of block repeats in stages) + out_chs: Tuple[int, ...] = (128, 256, 512, 1024) # number of output channels for blocks in stage + stride: Union[int, Tuple[int, ...]] = 2 # stride of stage + groups: Union[int, Tuple[int, ...]] = 1 # num kxk conv groups + block_ratio: Union[float, Tuple[float, ...]] = 1.0 + bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage + avg_down: Union[bool, Tuple[bool, ...]] = False + attn_layer: Optional[Union[str, Tuple[str, ...]]] = None + stage_type: Union[str, Tuple[str]] = 'csp' # stage type ('csp', 'cs2', 'dark') + block_type: Union[str, Tuple[str]] = 'bottle' # blocks type for stages ('bottle', 'dark') + + # cross-stage only + expand_ratio: Union[float, Tuple[float, ...]] = 1.0 + cross_linear: Union[bool, Tuple[bool, ...]] = False + down_growth: Union[bool, Tuple[bool, ...]] = False + + def __post_init__(self): + n = len(self.depth) + assert len(self.out_chs) == n + self.stride = _pad_arg(self.stride, n) + self.groups = _pad_arg(self.groups, n) + self.block_ratio = _pad_arg(self.block_ratio, n) + self.bottle_ratio = _pad_arg(self.bottle_ratio, n) + self.avg_down = _pad_arg(self.avg_down, n) + self.attn_layer = _pad_arg(self.attn_layer, n) + self.stage_type = _pad_arg(self.stage_type, n) + self.block_type = _pad_arg(self.block_type, n) + + self.expand_ratio = _pad_arg(self.expand_ratio, n) + self.cross_linear = _pad_arg(self.cross_linear, n) + self.down_growth = _pad_arg(self.down_growth, n) + + +@dataclass +class CspModelCfg: + stem: CspStemCfg + stages: CspStagesCfg + zero_init_last: bool = True # zero init last weight (usually bn) in residual path + act_layer: str = 'relu' + norm_layer: str = 'batchnorm' + aa_layer: Optional[str] = None # FIXME support string factory for this + + +def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False, act_layer='silu', focus=False): + if focus: + stem_cfg = CspStemCfg( + out_chs=make_divisible(64 * width_multiplier), + kernel_size=6, stride=2, padding=2, pool='') + else: + stem_cfg = CspStemCfg( + out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]), + kernel_size=3, stride=2, pool='') + return CspModelCfg( + stem=stem_cfg, + stages=CspStagesCfg( + out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]), + depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]), + stride=2, + bottle_ratio=1., + block_ratio=0.5, + avg_down=avg_down, + stage_type='cs3', + block_type='dark', + ), + act_layer=act_layer, + ) + + model_cfgs = dict( - cspresnet50=dict( - stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), - stage=dict( - out_chs=(128, 256, 512, 1024), + cspresnet50=CspModelCfg( + stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(2.,) * 4, - bottle_ratio=(0.5,) * 4, - block_ratio=(1.,) * 4, + out_chs=(128, 256, 512, 1024), + stride=(1, 2), + expand_ratio=2., + bottle_ratio=0.5, cross_linear=True, - ) + ), ), - cspresnet50d=dict( - stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), - stage=dict( - out_chs=(128, 256, 512, 1024), + cspresnet50d=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(2.,) * 4, - bottle_ratio=(0.5,) * 4, - block_ratio=(1.,) * 4, + out_chs=(128, 256, 512, 1024), + stride=(1,) + (2,), + expand_ratio=2., + bottle_ratio=0.5, + block_ratio=1., cross_linear=True, ) ), - cspresnet50w=dict( - stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), - stage=dict( - out_chs=(256, 512, 1024, 2048), + cspresnet50w=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(1.,) * 4, - bottle_ratio=(0.25,) * 4, - block_ratio=(0.5,) * 4, + out_chs=(256, 512, 1024, 2048), + stride=(1,) + (2,), + expand_ratio=1., + bottle_ratio=0.25, + block_ratio=0.5, cross_linear=True, ) ), - cspresnext50=dict( - stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), - stage=dict( - out_chs=(256, 512, 1024, 2048), + cspresnext50=CspModelCfg( + stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - groups=(32,) * 4, - exp_ratio=(1.,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, + out_chs=(256, 512, 1024, 2048), + stride=(1,) + (2,), + groups=32, + expand_ratio=1., + bottle_ratio=1., + block_ratio=0.5, cross_linear=True, ) ), - cspdarknet53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + cspdarknet53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - exp_ratio=(2.,) + (1.,) * 4, - bottle_ratio=(0.5,) + (1.0,) * 4, - block_ratio=(1.,) + (0.5,) * 4, + out_chs=(64, 128, 256, 512, 1024), + stride=2, + expand_ratio=(2.,) + (1.,), + bottle_ratio=(0.5,) + (1.,), + block_ratio=(1.,) + (0.5,), down_growth=True, - ) + block_type='dark', + ), + act_layer='leaky_relu', ), - darknet53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( + darknet17=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1,) * 5, out_chs=(64, 128, 256, 512, 1024), + stride=(2,), + bottle_ratio=(0.5,), + block_ratio=(1.,), + stage_type='dark', + block_type='dark', + ), + act_layer='leaky_relu', + ), + darknet21=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 1, 1, 2, 2), + out_chs=(64, 128, 256, 512, 1024), + stride=(2,), + bottle_ratio=(0.5,), + block_ratio=(1.,), + stage_type='dark', + block_type='dark', + + ), + act_layer='leaky_relu', + ), + sedarknet21=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 1, 1, 2, 2), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + attn_layer='se', + stage_type='dark', + block_type='dark', + + ), + act_layer='leaky_relu', + ), + darknet53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) - ) -) + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + stage_type='dark', + block_type='dark', + ), + act_layer='leaky_relu', + ), + darknetaa53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 2, 8, 8, 4), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + avg_down=True, + stage_type='dark', + block_type='dark', + ), + act_layer='leaky_relu', + ), + cs3darknet_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5), + cs3darknet_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67), + cs3darknet_l=_cs3darknet_cfg(), + cs3darknet_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33), -def create_stem( - in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='', - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): - stem = nn.Sequential() - if not isinstance(out_chs, (tuple, list)): - out_chs = [out_chs] - assert len(out_chs) - in_c = in_chans - for i, out_c in enumerate(out_chs): - conv_name = f'conv{i + 1}' - stem.add_module(conv_name, ConvNormAct( - in_c, out_c, kernel_size, stride=stride if i == 0 else 1, - act_layer=act_layer, norm_layer=norm_layer)) - in_c = out_c - last_conv = conv_name - if pool: - if aa_layer is not None: - stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) - stem.add_module('aa', aa_layer(channels=in_c, stride=2)) - else: - stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) - return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv])) + cs3darknet_focus_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True), + cs3darknet_focus_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True), + cs3darknet_focus_l=_cs3darknet_cfg(focus=True), + cs3darknet_focus_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True), + cs3sedarknet_xdw=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), + stages=CspStagesCfg( + depth=(3, 6, 12, 4), + out_chs=(256, 512, 1024, 2048), + stride=2, + groups=(1, 1, 256, 512), + bottle_ratio=0.5, + block_ratio=0.5, + attn_layer='se', + ), + ), +) -class ResBottleneck(nn.Module): + +class BottleneckBlock(nn.Module): """ ResNe(X)t Bottleneck Block """ def __init__( - self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False, - attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): - super(ResBottleneck, self).__init__() + self, + in_chs, + out_chs, + dilation=1, + bottle_ratio=0.25, + groups=1, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + attn_last=False, + attn_layer=None, + aa_layer=None, + drop_block=None, + drop_path=0. + ): + super(BottleneckBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) @@ -172,8 +367,8 @@ class ResBottleneck(nn.Module): self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None - self.drop_path = drop_path - self.act3 = act_layer(inplace=True) + self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() + self.act3 = create_act_layer(act_layer) def zero_init_last(self): nn.init.zeros_(self.conv3.bn.weight) @@ -187,9 +382,7 @@ class ResBottleneck(nn.Module): x = self.conv3(x) if self.attn3 is not None: x = self.attn3(x) - if self.drop_path is not None: - x = self.drop_path(x) - x = x + shortcut + x = self.drop_path(x) + shortcut # FIXME partial shortcut needed if first block handled as per original, not used for my current impl #x[:, :shortcut.size(1)] += shortcut x = self.act3(x) @@ -201,9 +394,19 @@ class DarkBlock(nn.Module): """ def __init__( - self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, - drop_block=None, drop_path=None): + self, + in_chs, + out_chs, + dilation=1, + bottle_ratio=0.5, + groups=1, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + attn_layer=None, + aa_layer=None, + drop_block=None, + drop_path=0. + ): super(DarkBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) @@ -211,8 +414,8 @@ class DarkBlock(nn.Module): self.conv2 = ConvNormActAa( mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, aa_layer=aa_layer, drop_layer=drop_block, **ckwargs) - self.attn = create_attn(attn_layer, channels=out_chs) - self.drop_path = drop_path + self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer) + self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() def zero_init_last(self): nn.init.zeros_(self.conv2.bn.weight) @@ -223,32 +426,51 @@ class DarkBlock(nn.Module): x = self.conv2(x) if self.attn is not None: x = self.attn(x) - if self.drop_path is not None: - x = self.drop_path(x) - x = x + shortcut + x = self.drop_path(x) + shortcut return x class CrossStage(nn.Module): """Cross Stage.""" def __init__( - self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1., - groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None, - block_fn=ResBottleneck, **block_kwargs): + self, + in_chs, + out_chs, + stride, + dilation, + depth, + block_ratio=1., + bottle_ratio=1., + expand_ratio=1., + groups=1, + first_dilation=None, + avg_down=False, + down_growth=False, + cross_linear=False, + block_dpr=None, + block_fn=BottleneckBlock, + **block_kwargs + ): super(CrossStage, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels - exp_chs = int(round(out_chs * exp_ratio)) + self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) block_out_chs = int(round(out_chs * block_ratio)) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: - self.conv_down = ConvNormActAa( - in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, - aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) + if avg_down: + self.conv_down = nn.Sequential( + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling + ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ) + else: + self.conv_down = ConvNormActAa( + in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) prev_chs = down_chs else: - self.conv_down = None + self.conv_down = nn.Identity() prev_chs = in_chs # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also, @@ -259,9 +481,15 @@ class CrossStage(nn.Module): self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs # transition convs @@ -269,38 +497,135 @@ class CrossStage(nn.Module): self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) def forward(self, x): - if self.conv_down is not None: - x = self.conv_down(x) + x = self.conv_down(x) x = self.conv_exp(x) - split = x.shape[1] // 2 - xs, xb = x[:, :split], x[:, split:] + xs, xb = x.split(self.expand_chs // 2, dim=1) xb = self.blocks(xb) xb = self.conv_transition_b(xb).contiguous() out = self.conv_transition(torch.cat([xs, xb], dim=1)) return out +class CrossStage3(nn.Module): + """Cross Stage 3. + Similar to CrossStage, but with only one transition conv for the output. + """ + def __init__( + self, + in_chs, + out_chs, + stride, + dilation, + depth, + block_ratio=1., + bottle_ratio=1., + expand_ratio=1., + groups=1, + first_dilation=None, + avg_down=False, + down_growth=False, + cross_linear=False, + block_dpr=None, + block_fn=BottleneckBlock, + **block_kwargs + ): + super(CrossStage3, self).__init__() + first_dilation = first_dilation or dilation + down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels + self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) + block_out_chs = int(round(out_chs * block_ratio)) + conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) + + if stride != 1 or first_dilation != dilation: + if avg_down: + self.conv_down = nn.Sequential( + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling + ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ) + else: + self.conv_down = ConvNormActAa( + in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) + prev_chs = down_chs + else: + self.conv_down = None + prev_chs = in_chs + + # expansion conv + self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs) + prev_chs = exp_chs // 2 # expanded output is split in 2 for blocks and cross stage + + self.blocks = nn.Sequential() + for i in range(depth): + self.blocks.add_module(str(i), block_fn( + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) + prev_chs = block_out_chs + + # transition convs + self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) + + def forward(self, x): + x = self.conv_down(x) + x = self.conv_exp(x) + x1, x2 = x.split(self.expand_chs // 2, dim=1) + x1 = self.blocks(x1) + out = self.conv_transition(torch.cat([x1, x2], dim=1)) + return out + + class DarkStage(nn.Module): """DarkNet stage.""" def __init__( - self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1, - first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs): + self, + in_chs, + out_chs, + stride, + dilation, + depth, + block_ratio=1., + bottle_ratio=1., + groups=1, + first_dilation=None, + avg_down=False, + block_fn=BottleneckBlock, + block_dpr=None, + **block_kwargs + ): super(DarkStage, self).__init__() first_dilation = first_dilation or dilation + conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) - self.conv_down = ConvNormActAa( - in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, - act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'), - aa_layer=block_kwargs.get('aa_layer', None)) + if avg_down: + self.conv_down = nn.Sequential( + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling + ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ) + else: + self.conv_down = ConvNormActAa( + in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) prev_chs = out_chs block_out_chs = int(round(out_chs * block_ratio)) self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs def forward(self, x): @@ -309,36 +634,131 @@ class DarkStage(nn.Module): return x -def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): - # get per stage args for stage and containing blocks, calculate strides to meet target output_stride - num_stages = len(cfg['depth']) - if 'groups' not in cfg: - cfg['groups'] = (1,) * num_stages - if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)): - cfg['down_growth'] = (cfg['down_growth'],) * num_stages - if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): - cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages - cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ - [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] - stage_strides = [] - stage_dilations = [] - stage_first_dilations = [] +def create_csp_stem( + in_chans=3, + out_chs=32, + kernel_size=3, + stride=2, + pool='', + padding='', + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + aa_layer=None +): + stem = nn.Sequential() + feature_info = [] + if not isinstance(out_chs, (tuple, list)): + out_chs = [out_chs] + stem_depth = len(out_chs) + assert stem_depth + assert stride in (1, 2, 4) + prev_feat = None + prev_chs = in_chans + last_idx = stem_depth - 1 + stem_stride = 1 + for i, chs in enumerate(out_chs): + conv_name = f'conv{i + 1}' + conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1 + if conv_stride > 1 and prev_feat is not None: + feature_info.append(prev_feat) + stem.add_module(conv_name, ConvNormAct( + prev_chs, chs, kernel_size, + stride=conv_stride, + padding=padding if i == 0 else '', + act_layer=act_layer, + norm_layer=norm_layer + )) + stem_stride *= conv_stride + prev_chs = chs + prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name])) + if pool: + assert stride > 2 + if prev_feat is not None: + feature_info.append(prev_feat) + if aa_layer is not None: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) + stem.add_module('aa', aa_layer(channels=prev_chs, stride=2)) + pool_name = 'aa' + else: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + pool_name = 'pool' + stem_stride *= 2 + prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name])) + feature_info.append(prev_feat) + return stem, feature_info + + +def _get_stage_fn(stage_type: str, stage_args): + assert stage_type in ('dark', 'csp', 'cs3') + if stage_type == 'dark': + stage_args.pop('expand_ratio', None) + stage_args.pop('cross_linear', None) + stage_args.pop('down_growth', None) + stage_fn = DarkStage + elif stage_type == 'csp': + stage_fn = CrossStage + else: + stage_fn = CrossStage3 + return stage_fn, stage_args + + +def _get_block_fn(stage_type: str, stage_args): + assert stage_type in ('dark', 'bottle') + if stage_type == 'dark': + return DarkBlock, stage_args + else: + return BottleneckBlock, stage_args + + +def create_csp_stages( + cfg: CspModelCfg, + drop_path_rate: float, + output_stride: int, + stem_feat: Dict[str, Any] +): + cfg_dict = asdict(cfg.stages) + num_stages = len(cfg.stages.depth) + cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \ + [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)] + stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())] + block_kwargs = dict( + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + aa_layer=cfg.aa_layer + ) + dilation = 1 - for cfg_stride in cfg['stride']: - stage_first_dilations.append(dilation) - if curr_stride >= output_stride: - dilation *= cfg_stride + net_stride = stem_feat['reduction'] + prev_chs = stem_feat['num_chs'] + prev_feat = stem_feat + feature_info = [] + stages = [] + for stage_idx, stage_args in enumerate(stage_args): + stage_fn, stage_args = _get_stage_fn(stage_args.pop('stage_type'), stage_args) + block_fn, stage_args = _get_block_fn(stage_args.pop('block_type'), stage_args) + stride = stage_args.pop('stride') + if stride != 1 and prev_feat: + feature_info.append(prev_feat) + if net_stride >= output_stride and stride > 1: + dilation *= stride stride = 1 - else: - stride = cfg_stride - curr_stride *= stride - stage_strides.append(stride) - stage_dilations.append(dilation) - cfg['stride'] = stage_strides - cfg['dilation'] = stage_dilations - cfg['first_dilation'] = stage_first_dilations - stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())] - return stage_args + net_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + + stages += [stage_fn( + prev_chs, + **stage_args, + stride=stride, + first_dilation=first_dilation, + dilation=dilation, + block_fn=block_fn, + **block_kwargs, + )] + prev_chs = stage_args['out_chs'] + prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}') + + feature_info.append(prev_feat) + return nn.Sequential(*stages), feature_info class CspNet(nn.Module): @@ -352,33 +772,40 @@ class CspNet(nn.Module): """ def __init__( - self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., - act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0., - zero_init_last=True, stage_fn=CrossStage, block_fn=ResBottleneck): + self, + cfg: CspModelCfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + drop_rate=0., + drop_path_rate=0., + zero_init_last=True + ): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) - layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + layer_args = dict( + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + aa_layer=cfg.aa_layer + ) + self.feature_info = [] # Construct the stem - self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args) - self.feature_info = [stem_feat_info] - prev_chs = stem_feat_info['num_chs'] - curr_stride = stem_feat_info['reduction'] # reduction does not include pool - if cfg['stem']['pool']: - curr_stride *= 2 + self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args) + self.feature_info.extend(stem_feat_info[:-1]) # Construct the stages - per_stage_args = _cfg_to_stage_args( - cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate) - self.stages = nn.Sequential() - for i, sa in enumerate(per_stage_args): - self.stages.add_module( - str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn)) - prev_chs = sa['out_chs'] - curr_stride *= sa['stride'] - self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages, stage_feat_info = create_csp_stages( + cfg, + drop_path_rate=drop_path_rate, + output_stride=output_stride, + stem_feat=stem_feat_info[-1], + ) + prev_chs = stage_feat_info[-1]['num_chs'] + self.feature_info.extend(stage_feat_info) # Construct the head self.num_features = prev_chs @@ -427,23 +854,26 @@ class CspNet(nn.Module): def _init_weights(module, name, zero_init_last=False): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(module, nn.BatchNorm2d): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) + if module.bias is not None: + nn.init.zeros_(module.bias) elif isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.01) - nn.init.zeros_(module.bias) + if module.bias is not None: + nn.init.zeros_(module.bias) elif zero_init_last and hasattr(module, 'zero_init_last'): module.zero_init_last() def _create_cspnet(variant, pretrained=False, **kwargs): - cfg_variant = variant.split('_')[0] - # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4)) + if variant.startswith('darknet') or variant.startswith('cspdarknet'): + # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] + default_out_indices = (0, 1, 2, 3, 4, 5) + else: + default_out_indices = (0, 1, 2, 3, 4) + out_indices = kwargs.pop('out_indices', default_out_indices) return build_model_with_cfg( CspNet, variant, pretrained, - model_cfg=model_cfgs[cfg_variant], + model_cfg=model_cfgs[variant], feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) @@ -469,22 +899,75 @@ def cspresnext50(pretrained=False, **kwargs): @register_model -def cspresnext50_iabn(pretrained=False, **kwargs): - norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu') - return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs) +def cspdarknet53(pretrained=False, **kwargs): + return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs) @register_model -def cspdarknet53(pretrained=False, **kwargs): - return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) +def darknet17(pretrained=False, **kwargs): + return _create_cspnet('darknet17', pretrained=pretrained, **kwargs) @register_model -def cspdarknet53_iabn(pretrained=False, **kwargs): - norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu') - return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs) +def darknet21(pretrained=False, **kwargs): + return _create_cspnet('darknet21', pretrained=pretrained, **kwargs) + + +@register_model +def sedarknet21(pretrained=False, **kwargs): + return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs) @register_model def darknet53(pretrained=False, **kwargs): - return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet53', pretrained=pretrained, **kwargs) + + +@register_model +def darknetaa53(pretrained=False, **kwargs): + return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_s(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_m(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_l(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_x(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_s(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_m(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_l(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_x(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs) + + +@register_model +def cs3sedarknet_xdw(pretrained=False, **kwargs): + return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs) diff --git a/timm/models/deit.py b/timm/models/deit.py index e6b4b025..8cb36bd6 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -1,12 +1,17 @@ """ DeiT - Data-efficient Image Transformers DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below -paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118 Modifications copyright 2021, Ross Wightman """ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. +from functools import partial + import torch from torch import nn as nn @@ -53,6 +58,46 @@ default_cfgs = { url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), + + 'deit3_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'), + 'deit3_small_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'), + 'deit3_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_large_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'), + 'deit3_large_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_huge_patch14_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'), + + 'deit3_small_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth', + crop_pct=1.0), + 'deit3_small_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_base_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth', + crop_pct=1.0), + 'deit3_base_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_large_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth', + crop_pct=1.0), + 'deit3_large_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_huge_patch14_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth', + crop_pct=1.0), } @@ -68,9 +113,10 @@ class VisionTransformerDistilled(VisionTransformer): super().__init__(*args, **kwargs, weight_init='skip') assert self.global_pool in ('token',) - self.num_tokens = 2 + self.num_prefix_tokens = 2 self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim)) self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() self.distilled_training = False # must set this True to train w/ distillation token @@ -133,7 +179,7 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs): model_cls = VisionTransformerDistilled if distilled else VisionTransformer model = build_model_with_cfg( model_cls, variant, pretrained, - pretrained_filter_fn=checkpoint_filter_fn, + pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True), **kwargs) return model @@ -220,3 +266,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs): model = _create_deit( 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) return model + + +@register_model +def deit3_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_huge_patch14_224(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py new file mode 100644 index 00000000..a81bd9b0 --- /dev/null +++ b/timm/models/edgenext.py @@ -0,0 +1,559 @@ +""" EdgeNeXt + +Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications` + - https://arxiv.org/abs/2206.10589 + +Original code and weights from https://github.com/mmaaz60/EdgeNeXt + +Modifications and additions for timm by / Copyright 2022, Ross Wightman +""" +import math +import torch +from collections import OrderedDict +from functools import partial +from typing import Tuple + +from torch import nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_module +from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from .helpers import named_apply, build_model_with_cfg, checkpoint_seq +from .registry import register_model + + +__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + edgenext_xx_small=_cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + edgenext_x_small=_cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + # edgenext_small=_cfg( + # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"), + edgenext_small=_cfg( # USI weights + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), + + edgenext_small_rw=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth', + test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), +) + + +@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method +class PositionalEncodingFourier(nn.Module): + def __init__(self, hidden_dim=32, dim=768, temperature=10000): + super().__init__() + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + + def forward(self, shape: Tuple[int, int, int]): + inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool) + y_embed = inv_mask.cumsum(1, dtype=torch.float32) + x_embed = inv_mask.cumsum(2, dtype=torch.float32) + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), + pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), + pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + + return pos + + +class ConvBlock(nn.Module): + def __init__( + self, + dim, + dim_out=None, + kernel_size=7, + stride=1, + conv_bias=True, + expand_ratio=4, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, drop_path=0., + ): + super().__init__() + dim_out = dim_out or dim + self.shortcut_after_dw = stride > 1 or dim != dim_out + + self.conv_dw = create_conv2d( + dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias) + self.norm = norm_layer(dim_out) + self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + if self.shortcut_after_dw: + shortcut = x + + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class CrossCovarianceAttn(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) + + # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature'} + + +class SplitTransposeBlock(nn.Module): + def __init__( + self, + dim, + num_scales=1, + num_heads=8, + expand_ratio=4, + use_pos_emb=True, + conv_bias=True, + qkv_bias=True, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + drop_path=0., + attn_drop=0., + proj_drop=0. + ): + super().__init__() + width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales))) + self.width = width + self.num_scales = max(1, num_scales - 1) + + convs = [] + for i in range(self.num_scales): + convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias)) + self.convs = nn.ModuleList(convs) + + self.pos_embd = None + if use_pos_emb: + self.pos_embd = PositionalEncodingFourier(dim=dim) + self.norm_xca = norm_layer(dim) + self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.xca = CrossCovarianceAttn( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) + + self.norm = norm_layer(dim, eps=1e-6) + self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + # scales code re-written for torchscript as per my res2net fixes -rw + spx = torch.split(x, self.width, 1) + spo = [] + sp = spx[0] + for i, conv in enumerate(self.convs): + if i > 0: + sp = sp + spx[i] + sp = conv(sp) + spo.append(sp) + spo.append(spx[-1]) + x = torch.cat(spo, 1) + + # XCA + B, C, H, W = x.shape + x = x.reshape(B, C, H * W).permute(0, 2, 1) + if self.pos_embd is not None: + pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1) + x = x + pos_encoding + x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x))) + x = x.reshape(B, H, W, C) + + # Inverted Bottleneck + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class EdgeNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=2, + depth=2, + num_global_blocks=1, + num_heads=4, + scales=2, + kernel_size=7, + expand_ratio=4, + use_pos_emb=False, + downsample_block=False, + conv_bias=True, + ls_init_value=1.0, + drop_path_rates=None, + norm_layer=LayerNorm2d, + norm_layer_cl=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU + ): + super().__init__() + self.grad_checkpointing = False + + if downsample_block or stride == 1: + self.downsample = nn.Identity() + else: + self.downsample = nn.Sequential( + norm_layer(in_chs), + nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias) + ) + in_chs = out_chs + + stage_blocks = [] + for i in range(depth): + if i < depth - num_global_blocks: + stage_blocks.append( + ConvBlock( + dim=in_chs, + dim_out=out_chs, + stride=stride if downsample_block and i == 0 else 1, + conv_bias=conv_bias, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + else: + stage_blocks.append( + SplitTransposeBlock( + dim=in_chs, + num_scales=scales, + num_heads=num_heads, + expand_ratio=expand_ratio, + use_pos_emb=use_pos_emb, + conv_bias=conv_bias, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) + + def forward(self, x): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class EdgeNeXt(nn.Module): + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + dims=(24, 48, 88, 168), + depths=(3, 3, 9, 3), + global_block_counts=(0, 1, 1, 1), + kernel_sizes=(3, 5, 7, 9), + heads=(8, 8, 8, 8), + d2_scales=(2, 2, 3, 4), + use_pos_emb=(False, True, False, False), + ls_init_value=1e-6, + head_init_scale=1., + expand_ratio=4, + downsample_block=False, + conv_bias=True, + stem_type='patch', + head_norm_first=False, + act_layer=nn.GELU, + drop_path_rate=0., + drop_rate=0., + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.drop_rate = drop_rate + norm_layer = partial(LayerNorm2d, eps=1e-6) + norm_layer_cl = partial(nn.LayerNorm, eps=1e-6) + self.feature_info = [] + + assert stem_type in ('patch', 'overlap') + if stem_type == 'patch': + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias), + norm_layer(dims[0]), + ) + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias), + norm_layer(dims[0]), + ) + + curr_stride = 4 + stages = [] + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + in_chs = dims[0] + for i in range(4): + stride = 2 if curr_stride == 2 or i > 0 else 1 + # FIXME support dilation / output_stride + curr_stride *= stride + stages.append(EdgeNeXtStage( + in_chs=in_chs, + out_chs=dims[i], + stride=stride, + depth=depths[i], + num_global_blocks=global_block_counts[i], + num_heads=heads[i], + drop_path_rates=dp_rates[i], + scales=d2_scales[i], + expand_ratio=expand_ratio, + kernel_size=kernel_sizes[i], + use_pos_emb=use_pos_emb[i], + ls_init_value=ls_init_value, + downsample_block=downsample_block, + conv_bias=conv_bias, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + act_layer=act_layer, + )) + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 + in_chs = dims[i] + self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')] + + self.stages = nn.Sequential(*stages) + + self.num_features = dims[-1] + self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) + + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes=0, global_pool=None): + if global_pool is not None: + self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( + x = self.head.global_pool(x) + x = self.head.norm(x) + x = self.head.flatten(x) + x = self.head.drop(x) + return x if pre_logits else self.head.fc(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _init_weights(module, name=None, head_init_scale=1.0): + if isinstance(module, nn.Conv2d): + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_tf_(module.weight, std=.02) + nn.init.zeros_(module.bias) + if name and 'head.' in name: + module.weight.data.mul_(head_init_scale) + module.bias.data.mul_(head_init_scale) + + +def checkpoint_filter_fn(state_dict, model): + """ Remap FB checkpoints -> timm """ + if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: + return state_dict # non-FB checkpoint + + # models were released as train checkpoints... :/ + if 'model_ema' in state_dict: + state_dict = state_dict['model_ema'] + elif 'model' in state_dict: + state_dict = state_dict['model'] + elif 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + out_dict = {} + import re + for k, v in state_dict.items(): + k = k.replace('downsample_layers.0.', 'stem.') + k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) + k = k.replace('dwconv', 'conv_dw') + k = k.replace('pwconv', 'mlp.fc') + k = k.replace('head.', 'head.fc.') + if k.startswith('norm.'): + k = k.replace('norm', 'head.norm') + if v.ndim == 2 and 'head' not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + return out_dict + + +def _create_edgenext(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + EdgeNeXt, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs) + return model + + +@register_model +def edgenext_xx_small(pretrained=False, **kwargs): + # 1.33M & 260.58M @ 256 resolution + # 71.23% Top-1 accuracy + # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS + # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS + model_kwargs = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4), **kwargs) + return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_x_small(pretrained=False, **kwargs): + # 2.34M & 538.0M @ 256 resolution + # 75.00% Top-1 accuracy + # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=31.61 versus 28.49 for MobileViT_XS + # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS + model_kwargs = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4), **kwargs) + return _create_edgenext('edgenext_x_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_small(pretrained=False, **kwargs): + # 5.59M & 1260.59M @ 256 resolution + # 79.43% Top-1 accuracy + # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=20.47 versus 18.86 for MobileViT_S + # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S + model_kwargs = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304), **kwargs) + return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_small_rw(pretrained=False, **kwargs): + # 5.59M & 1260.59M @ 256 resolution + # 79.43% Top-1 accuracy + # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=20.47 versus 18.86 for MobileViT_S + # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S + model_kwargs = dict( + depths=(3, 3, 9, 3), dims=(48, 96, 192, 384), + downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs) + return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **model_kwargs) + diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 1276b68e..fda84171 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -455,18 +455,26 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter): filter_kwargs(kwargs, names=kwargs_filter) -def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None): +def resolve_pretrained_cfg(variant: str, pretrained_cfg=None): if pretrained_cfg and isinstance(pretrained_cfg, dict): - # highest priority, pretrained_cfg available and passed explicitly + # highest priority, pretrained_cfg available and passed as arg return deepcopy(pretrained_cfg) - if kwargs and 'pretrained_cfg' in kwargs: - # next highest, pretrained_cfg in a kwargs dict, pop and return - pretrained_cfg = kwargs.pop('pretrained_cfg', {}) - if pretrained_cfg: - return deepcopy(pretrained_cfg) - # lookup pretrained cfg in model registry by variant + # fallback to looking up pretrained cfg in model registry by variant identifier pretrained_cfg = get_pretrained_cfg(variant) - assert pretrained_cfg + if not pretrained_cfg: + _logger.warning( + f"No pretrained configuration specified for {variant} model. Using a default." + f" Please add a config to the model pretrained_cfg registry or pass explicitly.") + pretrained_cfg = dict( + url='', + num_classes=1000, + input_size=(3, 224, 224), + pool_size=None, + crop_pct=.9, + interpolation='bicubic', + first_conv='', + classifier='', + ) return pretrained_cfg diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index e34de657..2c6e7eb7 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3): def _create_inception_v3(variant, pretrained=False, **kwargs): - pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) aux_logits = kwargs.pop('aux_logits', False) if aux_logits: assert not kwargs.pop('features_only', False) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 7e9e7b19..b9eeec0f 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -25,8 +25,8 @@ from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, LayerNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct +from .norm import GroupNorm, GroupNorm1, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d @@ -39,4 +39,4 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .trace_utils import _assert, _float_to_int -from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ +from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py index af010573..9e7c64b8 100644 --- a/timm/models/layers/conv_bn_act.py +++ b/timm/models/layers/conv_bn_act.py @@ -2,6 +2,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import functools from torch import nn as nn from .create_conv2d import create_conv2d @@ -40,12 +41,26 @@ class ConvNormAct(nn.Module): ConvBnAct = ConvNormAct +def create_aa(aa_layer, channels, stride=2, enable=True): + if not aa_layer or not enable: + return nn.Identity() + if isinstance(aa_layer, functools.partial): + if issubclass(aa_layer.func, nn.AvgPool2d): + return aa_layer() + else: + return aa_layer(channels) + elif issubclass(aa_layer, nn.AvgPool2d): + return aa_layer(stride) + else: + return aa_layer(channels=channels, stride=stride) + + class ConvNormActAa(nn.Module): def __init__( self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): super(ConvNormActAa, self).__init__() - use_aa = aa_layer is not None + use_aa = aa_layer is not None and stride == 2 self.conv = create_conv2d( in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, @@ -56,7 +71,7 @@ class ConvNormActAa(nn.Module): # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) - self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity() + self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) @property def in_channels(self): diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index 028c0f75..cc7e91ea 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -22,7 +22,7 @@ def get_attn(attn_type): if isinstance(attn_type, torch.nn.Module): return attn_type module_cls = None - if attn_type is not None: + if attn_type: if isinstance(attn_type, str): attn_type = attn_type.lower() # Lightweight attention modules (channel and/or coarse spatial). diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index ae065277..1ab1c8f5 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -164,3 +164,6 @@ class DropPath(nn.Module): def forward(self, x): return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index b643302c..ea776207 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -256,8 +256,9 @@ class EvoNorm2dS0a(EvoNorm2dS0): class EvoNorm2dS1(nn.Module): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + apply_act=True, act_layer=None, eps=1e-5, **_): super().__init__() + act_layer = act_layer or nn.SiLU self.apply_act = apply_act # apply activation (non-linearity) if act_layer is not None and apply_act: self.act = create_act_layer(act_layer) @@ -290,7 +291,7 @@ class EvoNorm2dS1(nn.Module): class EvoNorm2dS1a(EvoNorm2dS1): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): + apply_act=True, act_layer=None, eps=1e-3, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) @@ -305,8 +306,9 @@ class EvoNorm2dS1a(EvoNorm2dS1): class EvoNorm2dS2(nn.Module): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + apply_act=True, act_layer=None, eps=1e-5, **_): super().__init__() + act_layer = act_layer or nn.SiLU self.apply_act = apply_act # apply activation (non-linearity) if act_layer is not None and apply_act: self.act = create_act_layer(act_layer) @@ -338,7 +340,7 @@ class EvoNorm2dS2(nn.Module): class EvoNorm2dS2a(EvoNorm2dS2): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): + apply_act=True, act_layer=None, eps=1e-3, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 85297420..1677dbfa 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -14,11 +14,59 @@ class GroupNorm(nn.GroupNorm): return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) +class GroupNorm1(nn.GroupNorm): + """ Group Normalization with 1 group. + Input: tensor in shape [B, C, *] + """ + + def __init__(self, num_channels, **kwargs): + super().__init__(1, num_channels, **kwargs) + + class LayerNorm2d(nn.LayerNorm): - """ LayerNorm for channels of '2D' spatial BCHW tensors """ - def __init__(self, num_channels): - super().__init__(num_channels) + """ LayerNorm for channels of '2D' spatial NCHW tensors """ + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +def _is_contiguous(tensor: torch.Tensor) -> bool: + # jit is oh so lovely :/ + # if torch.jit.is_tracing(): + # return True + if torch.jit.is_scripting(): + return tensor.is_contiguous() + else: + return tensor.is_contiguous(memory_format=torch.contiguous_format) + + +@torch.jit.script +def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): + s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) + x = (x - u) * torch.rsqrt(s + eps) + x = x * weight[:, None, None] + bias[:, None, None] + return x + + +class LayerNormExp2d(nn.LayerNorm): + """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). + + Experimental implementation w/ manual norm for tensors non-contiguous tensors. + + This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last + layout. However, benefits are not always clear and can perform worse on other GPUs. + """ + + def __init__(self, num_channels, eps=1e-6): + super().__init__(num_channels, eps=eps) + + def forward(self, x) -> torch.Tensor: + if _is_contiguous(x): + x = F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + else: + x = _layer_norm_cf(x, self.weight, self.bias, self.eps) + return x diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 0f386260..6bfdf201 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -1,6 +1,6 @@ """ Normalization + Activation Layers """ -from typing import Union, List +from typing import Union, List, Optional, Any import torch from torch import nn as nn @@ -18,10 +18,29 @@ class BatchNormAct2d(nn.BatchNorm2d): instead of composing it as a .bn member. """ def __init__( - self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): - super(BatchNormAct2d, self).__init__( - num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + device=None, + dtype=None + ): + try: + factory_kwargs = {'device': device, 'dtype': dtype} + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats, + **factory_kwargs + ) + except TypeError: + # NOTE for backwards compat with old PyTorch w/o factory device/dtype support + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) self.drop = drop_layer() if drop_layer is not None else nn.Identity() act_layer = get_act_layer(act_layer) # string -> nn.Module if act_layer is not None and apply_act: @@ -81,6 +100,62 @@ class BatchNormAct2d(nn.BatchNorm2d): return x +class SyncBatchNormAct(nn.SyncBatchNorm): + # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254) + # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers + # but ONLY when used in conjunction with the timm conversion function below. + # Do not create this module directly or use the PyTorch conversion function. + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = super().forward(x) # SyncBN doesn't work with torchscript anyways, so this is fine + if hasattr(self, "drop"): + x = self.drop(x) + if hasattr(self, "act"): + x = self.act(x) + return x + + +def convert_sync_batchnorm(module, process_group=None): + # convert both BatchNorm and BatchNormAct layers to Synchronized variants + module_output = module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + if isinstance(module, BatchNormAct2d): + # convert timm norm + act layer + module_output = SyncBatchNormAct( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group=process_group, + ) + # set act and drop attr from the original module + module_output.act = module.act + module_output.drop = module.drop + else: + # convert standard BatchNorm layers + module_output = torch.nn.SyncBatchNorm( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group, + ) + if module.affine: + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + if hasattr(module, "qconfig"): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module(name, convert_sync_batchnorm(child, process_group)) + del module + return module_output + + def group_norm_tpu(x, w, b, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False): # This is a workaround for some odd behaviour running on PyTorch XLA w/ TPUs. x_shape = x.shape diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py index 98c0bf53..5826d8c9 100644 --- a/timm/models/layers/test_time_pool.py +++ b/timm/models/layers/test_time_pool.py @@ -36,7 +36,7 @@ class TestTimePoolHead(nn.Module): return x.view(x.size(0), -1) -def apply_test_time_pool(model, config, use_test_size=True): +def apply_test_time_pool(model, config, use_test_size=False): test_time_pool = False if not hasattr(model, 'default_cfg') or not model.default_cfg: return model, False diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py index 305a2fd0..4a160931 100644 --- a/timm/models/layers/weight_init.py +++ b/timm/models/layers/weight_init.py @@ -49,6 +49,11 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. + Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution @@ -62,6 +67,35 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): return _no_grad_trunc_normal_(tensor, mean, std, a, b) +def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + _no_grad_trunc_normal_(tensor, 0, 1.0, a, b) + with torch.no_grad(): + tensor.mul_(std).add_(mean) + return tensor + + def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == 'fan_in': @@ -75,7 +109,7 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): if distribution == "truncated_normal": # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978) elif distribution == "normal": tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 1c55bd1c..2a3ab924 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -1,7 +1,8 @@ """ MobileViT Paper: -`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 +V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 +V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680 MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below) License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source) @@ -13,7 +14,7 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022 # Copyright (C) 2020 Apple Inc. All Rights Reserved. # import math -from typing import Union, Callable, Dict, Tuple, Optional +from typing import Union, Callable, Dict, Tuple, Optional, Sequence import torch from torch import nn @@ -21,7 +22,7 @@ import torch.nn.functional as F from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups from .fx_features import register_notrace_module -from .layers import to_2tuple, make_divisible +from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath from .vision_transformer import Block as TransformerBlock from .helpers import build_model_with_cfg from .registry import register_model @@ -48,6 +49,48 @@ default_cfgs = { 'mobilevit_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'), 'semobilevit_s': _cfg(), + + 'mobilevitv2_050': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth', + crop_pct=0.888), + 'mobilevitv2_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth', + crop_pct=0.888), + 'mobilevitv2_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth', + crop_pct=0.888), + 'mobilevitv2_125': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth', + crop_pct=0.888), + 'mobilevitv2_150': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth', + crop_pct=0.888), + 'mobilevitv2_175': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth', + crop_pct=0.888), + 'mobilevitv2_200': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth', + crop_pct=0.888), + + 'mobilevitv2_150_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth', + crop_pct=0.888), + 'mobilevitv2_175_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth', + crop_pct=0.888), + 'mobilevitv2_200_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth', + crop_pct=0.888), + + 'mobilevitv2_150_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'mobilevitv2_175_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'mobilevitv2_200_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), } @@ -72,6 +115,40 @@ def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4, ) +def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5): + # inverted residual + mobilevit blocks as per MobileViT network + return ( + _inverted_residual_block(d=d, c=c, s=s, br=br), + ByoBlockCfg( + type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1, + block_kwargs=dict( + transformer_depth=transformer_depth, + patch_size=patch_size) + ) + ) + + +def _mobilevitv2_cfg(multiplier=1.0): + chs = (64, 128, 256, 384, 512) + if multiplier != 1.0: + chs = tuple([int(c * multiplier) for c in chs]) + cfg = ByoModelCfg( + blocks=( + _inverted_residual_block(d=1, c=chs[0], s=1, br=2.0), + _inverted_residual_block(d=2, c=chs[1], s=2, br=2.0), + _mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2), + _mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4), + _mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3), + ), + stem_chs=int(32 * multiplier), + stem_type='3x3', + stem_pool='', + downsample='', + act_layer='silu', + ) + return cfg + + model_cfgs = dict( mobilevit_xxs=ByoModelCfg( blocks=( @@ -137,11 +214,19 @@ model_cfgs = dict( attn_kwargs=dict(rd_ratio=1/8), num_features=640, ), + + mobilevitv2_050=_mobilevitv2_cfg(.50), + mobilevitv2_075=_mobilevitv2_cfg(.75), + mobilevitv2_125=_mobilevitv2_cfg(1.25), + mobilevitv2_100=_mobilevitv2_cfg(1.0), + mobilevitv2_150=_mobilevitv2_cfg(1.5), + mobilevitv2_175=_mobilevitv2_cfg(1.75), + mobilevitv2_200=_mobilevitv2_cfg(2.0), ) @register_notrace_module -class MobileViTBlock(nn.Module): +class MobileVitBlock(nn.Module): """ MobileViT block Paper: https://arxiv.org/abs/2110.02178?context=cs.LG """ @@ -165,9 +250,9 @@ class MobileViTBlock(nn.Module): drop_path_rate: float = 0., layers: LayerFn = None, transformer_norm_layer: Callable = nn.LayerNorm, - downsample: str = '' + **kwargs, # eat unused args ): - super(MobileViTBlock, self).__init__() + super(MobileVitBlock, self).__init__() layers = layers or LayerFn() groups = num_groups(group_size, in_chs) @@ -241,7 +326,270 @@ class MobileViTBlock(nn.Module): return x -register_block('mobilevit', MobileViTBlock) +class LinearSelfAttention(nn.Module): + """ + This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680` + This layer can be used for self- as well as cross-attention. + Args: + embed_dim (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + attn_drop (float): Dropout value for context scores. Default: 0.0 + bias (bool): Use bias in learnable layers. Default: True + Shape: + - Input: :math:`(N, C, P, N)` where :math:`N` is the batch size, :math:`C` is the input channels, + :math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches + - Output: same as the input + .. note:: + For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels + in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor, + we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be + expensive on resource-constrained devices) that may be required to convert the unfolded tensor from + channel-first to channel-last format in case of a linear layer. + """ + + def __init__( + self, + embed_dim: int, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + + self.qkv_proj = nn.Conv2d( + in_channels=embed_dim, + out_channels=1 + (2 * embed_dim), + bias=bias, + kernel_size=1, + ) + self.attn_drop = nn.Dropout(attn_drop) + self.out_proj = nn.Conv2d( + in_channels=embed_dim, + out_channels=embed_dim, + bias=bias, + kernel_size=1, + ) + self.out_drop = nn.Dropout(proj_drop) + + def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor: + # [B, C, P, N] --> [B, h + 2d, P, N] + qkv = self.qkv_proj(x) + + # Project x into query, key and value + # Query --> [B, 1, P, N] + # value, key --> [B, d, P, N] + query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1) + + # apply softmax along N dimension + context_scores = F.softmax(query, dim=-1) + context_scores = self.attn_drop(context_scores) + + # Compute context vector + # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1] + context_vector = (key * context_scores).sum(dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + out = self.out_drop(out) + return out + + @torch.jit.ignore() + def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + # x --> [B, C, P, N] + # x_prev = [B, C, P, M] + batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape + q_patch_area, q_num_patches = x.shape[-2:] + + assert ( + kv_patch_area == q_patch_area + ), "The number of pixels in a patch for query and key_value should be the same" + + # compute query, key, and value + # [B, C, P, M] --> [B, 1 + d, P, M] + qk = F.conv2d( + x_prev, + weight=self.qkv_proj.weight[:self.embed_dim + 1], + bias=self.qkv_proj.bias[:self.embed_dim + 1], + ) + + # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M] + query, key = qk.split([1, self.embed_dim], dim=1) + # [B, C, P, N] --> [B, d, P, N] + value = F.conv2d( + x, + weight=self.qkv_proj.weight[self.embed_dim + 1], + bias=self.qkv_proj.bias[self.embed_dim + 1] if self.qkv_proj.bias is not None else None, + ) + + # apply softmax along M dimension + context_scores = F.softmax(query, dim=-1) + context_scores = self.attn_drop(context_scores) + + # compute context vector + # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1] + context_vector = (key * context_scores).sum(dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + out = self.out_drop(out) + return out + + def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + if x_prev is None: + return self._forward_self_attn(x) + else: + return self._forward_cross_attn(x, x_prev=x_prev) + + +class LinearTransformerBlock(nn.Module): + """ + This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 paper <>`_ + Args: + embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)` + mlp_ratio (float): Inner dimension ratio of the FFN relative to embed_dim + drop (float): Dropout rate. Default: 0.0 + attn_drop (float): Dropout rate for attention in multi-head attention. Default: 0.0 + drop_path (float): Stochastic depth rate Default: 0.0 + norm_layer (Callable): Normalization layer. Default: layer_norm_2d + Shape: + - Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim, + :math:`P` is number of pixels in a patch, and :math:`N` is number of patches, + - Output: same shape as the input + """ + + def __init__( + self, + embed_dim: int, + mlp_ratio: float = 2.0, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer=None, + norm_layer=None, + ) -> None: + super().__init__() + act_layer = act_layer or nn.SiLU + norm_layer = norm_layer or GroupNorm1 + + self.norm1 = norm_layer(embed_dim) + self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop) + self.drop_path1 = DropPath(drop_path) + + self.norm2 = norm_layer(embed_dim) + self.mlp = ConvMlp( + in_features=embed_dim, + hidden_features=int(embed_dim * mlp_ratio), + act_layer=act_layer, + drop=drop) + self.drop_path2 = DropPath(drop_path) + + def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + if x_prev is None: + # self-attention + x = x + self.drop_path1(self.attn(self.norm1(x))) + else: + # cross-attention + res = x + x = self.norm1(x) # norm + x = self.attn(x, x_prev) # attn + x = self.drop_path1(x) + res # residual + + # Feed forward network + x = x + self.drop_path2(self.mlp(self.norm2(x))) + return x + + +@register_notrace_module +class MobileVitV2Block(nn.Module): + """ + This class defines the `MobileViTv2 block <>`_ + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 3, + bottle_ratio: float = 1.0, + group_size: Optional[int] = 1, + dilation: Tuple[int, int] = (1, 1), + mlp_ratio: float = 2.0, + transformer_dim: Optional[int] = None, + transformer_depth: int = 2, + patch_size: int = 8, + attn_drop: float = 0., + drop: int = 0., + drop_path_rate: float = 0., + layers: LayerFn = None, + transformer_norm_layer: Callable = GroupNorm1, + **kwargs, # eat unused args + ): + super(MobileVitV2Block, self).__init__() + layers = layers or LayerFn() + groups = num_groups(group_size, in_chs) + out_chs = out_chs or in_chs + transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs) + + self.conv_kxk = layers.conv_norm_act( + in_chs, in_chs, kernel_size=kernel_size, + stride=1, groups=groups, dilation=dilation[0]) + self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False) + + self.transformer = nn.Sequential(*[ + LinearTransformerBlock( + transformer_dim, + mlp_ratio=mlp_ratio, + attn_drop=attn_drop, + drop=drop, + drop_path=drop_path_rate, + act_layer=layers.act, + norm_layer=transformer_norm_layer + ) + for _ in range(transformer_depth) + ]) + self.norm = transformer_norm_layer(transformer_dim) + + self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False) + + self.patch_size = to_2tuple(patch_size) + self.patch_area = self.patch_size[0] * self.patch_size[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + patch_h, patch_w = self.patch_size + new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w + num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w + num_patches = num_patch_h * num_patch_w # N + if new_h != H or new_w != W: + x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True) + + # Local representation + x = self.conv_kxk(x) + x = self.conv_1x1(x) + + # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N] + C = x.shape[1] + x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4) + x = x.reshape(B, C, -1, num_patches) + + # Global representations + x = self.transformer(x) + x = self.norm(x) + + # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W] + x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3) + x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w) + + x = self.conv_proj(x) + return x + + +register_block('mobilevit', MobileVitBlock) +register_block('mobilevit2', MobileVitV2Block) def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): @@ -252,6 +600,14 @@ def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): **kwargs) +def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + ByobNet, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + @register_model def mobilevit_xxs(pretrained=False, **kwargs): return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs) @@ -269,4 +625,75 @@ def mobilevit_s(pretrained=False, **kwargs): @register_model def semobilevit_s(pretrained=False, **kwargs): - return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) \ No newline at end of file + return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_050(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_075(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_100(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_125(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs) \ No newline at end of file diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 17d657b0..a95195b4 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -26,7 +26,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp +from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 from .registry import register_model @@ -80,15 +80,6 @@ class PatchEmbed(nn.Module): return x -class GroupNorm1(nn.GroupNorm): - """ Group Normalization with 1 group. - Input: tensor in shape [B, C, H, W] - """ - - def __init__(self, num_channels, **kwargs): - super().__init__(1, num_channels, **kwargs) - - class Pooling(nn.Module): def __init__(self, pool_size=3): super().__init__() diff --git a/timm/models/resnet.py b/timm/models/resnet.py index a7f0c0f6..e5a6b791 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -35,6 +35,16 @@ def _cfg(url='', **kwargs): default_cfgs = { # ResNet and Wide ResNet + 'resnet10t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet10t_176_c3-f3215ab1.pth', + input_size=(3, 176, 176), pool_size=(6, 6), + test_crop_pct=0.95, test_input_size=(3, 224, 224), + first_conv='conv1.0'), + 'resnet14t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet14t_176_c3-c4ed2c37.pth', + input_size=(3, 176, 176), pool_size=(6, 6), + test_crop_pct=0.95, test_input_size=(3, 224, 224), + first_conv='conv1.0'), 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet18d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth', @@ -262,6 +272,9 @@ default_cfgs = { 'resnetblur101d': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), + 'resnetaa50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic'), 'resnetaa50d': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), @@ -723,6 +736,24 @@ def _create_resnet(variant, pretrained=False, **kwargs): return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) +@register_model +def resnet10t(pretrained=False, **kwargs): + """Constructs a ResNet-10-T model. + """ + model_args = dict( + block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet10t', pretrained, **model_args) + + +@register_model +def resnet14t(pretrained=False, **kwargs): + """Constructs a ResNet-14-T model. + """ + model_args = dict( + block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet14t', pretrained, **model_args) + + @register_model def resnet18(pretrained=False, **kwargs): """Constructs a ResNet-18 model. @@ -1436,6 +1467,14 @@ def resnetblur101d(pretrained=False, **kwargs): return _create_resnet('resnetblur101d', pretrained, **model_args) +@register_model +def resnetaa50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with avgpool anti-aliasing + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, **kwargs) + return _create_resnet('resnetaa50', pretrained, **model_args) + + @register_model def resnetaa50d(pretrained=False, **kwargs): """Constructs a ResNet-50-D model with avgpool anti-aliasing diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 59fd7849..c92c22a3 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -325,8 +325,8 @@ class VisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', - embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): + class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): """ Args: img_size (int, tuple): input image size @@ -360,15 +360,17 @@ class VisionTransformer(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 if class_token else 0 + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None - self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -428,11 +430,24 @@ class VisionTransformer(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def _pos_embed(self, x): + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + self.pos_embed + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed + return self.pos_drop(x) + def forward_features(self, x): x = self.patch_embed(x) - if self.cls_token is not None: - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = self.pos_drop(x + self.pos_embed) + x = self._pos_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: @@ -442,7 +457,7 @@ class VisionTransformer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: - x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) return x if pre_logits else self.head(x) @@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) if pos_embed_w.shape != model.pos_embed.shape: pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights - pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + pos_embed_w, + model.pos_embed, + getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) @@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) -def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): +def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] - if num_tokens: - posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] - ntok_new -= num_tokens + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] + ntok_new -= num_prefix_tokens else: - posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 @@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) - posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) return posemb -def checkpoint_filter_fn(state_dict, model): +def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re out_dict = {} if 'model' in state_dict: # For deit models state_dict = state_dict['model'] + for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: # For old models that I trained prior to conv based patchification O, I, H, W = model.patch_embed.proj.weight.shape v = v.reshape(O, -1, H, W) - elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( - v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + v, + model.pos_embed, + getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + elif adapt_layer_scale and 'gamma_' in k: + # remap layer-scale gamma into sub-module (deit3 models) + k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) elif 'pre_logits' in k: # NOTE representation layer removed as not used in latest 21k/1k pretrained weights continue @@ -633,7 +661,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) model = build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_cfg=pretrained_cfg, diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 0c2ac376..52b3ce45 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -8,6 +8,7 @@ import math import logging from functools import partial from collections import OrderedDict +from dataclasses import dataclass from typing import Optional, Tuple import torch @@ -47,9 +48,16 @@ default_cfgs = { 'vit_relpos_base_patch16_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'), + 'vit_srelpos_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'), + 'vit_srelpos_medium_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'), + + 'vit_relpos_medium_patch16_cls_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'), 'vit_relpos_base_patch16_cls_224': _cfg( url=''), - 'vit_relpos_base_patch16_gapcls_224': _cfg( + 'vit_relpos_base_patch16_clsgap_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'), 'vit_relpos_small_patch16_rpn_224': _cfg(url=''), @@ -59,35 +67,43 @@ default_cfgs = { } -def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor: - # cut and paste w/ modifications from swin / beit codebase - # cls to token & token 2 cls & cls to cls +def gen_relative_position_index( + q_size: Tuple[int, int], + k_size: Tuple[int, int] = None, + class_token: bool = False) -> torch.Tensor: + # Adapted with significant modifications from Swin / BeiT codebases # get pair-wise relative position index for each token inside the window - window_area = win_size[0] * win_size[1] - coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww - relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += win_size[1] - 1 - relative_coords[:, :, 0] *= 2 * win_size[1] - 1 + q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww + if k_size is None: + k_coords = q_coords + k_size = q_size + else: + # different q vs k sizes is a WIP + k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1) + relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2 + _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0) + if class_token: - num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3 - relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) - relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias + # NOTE not intended or tested with MLP log-coords + max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1])) + num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3 + relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0]) relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 relative_position_index[0, 0] = num_relative_distance - 1 - else: - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - return relative_position_index + + return relative_position_index.contiguous() def gen_relative_log_coords( win_size: Tuple[int, int], pretrained_win_size: Tuple[int, int] = (0, 0), - mode='swin' + mode='swin', ): - # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well + assert mode in ('swin', 'cr', 'rw') + # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) @@ -100,12 +116,22 @@ def gen_relative_log_coords( relative_coords_table[:, :, 0] /= (win_size[0] - 1) relative_coords_table[:, :, 1] /= (win_size[1] - 1) relative_coords_table *= 8 # normalize to -8, 8 - scale = math.log2(8) + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) / math.log2(8) else: - # FIXME we should support a form of normalization (to -1/1) for this mode? - scale = math.log2(math.e) - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - 1.0 + relative_coords_table.abs()) / scale + if mode == 'rw': + # cr w/ window size normalization -> [-1,1] log coords + relative_coords_table[:, :, 0] /= (win_size[0] - 1) + relative_coords_table[:, :, 1] /= (win_size[1] - 1) + relative_coords_table *= 8 # scale to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) + relative_coords_table /= math.log2(9) # -> [-1, 1] + else: + # mode == 'cr' + relative_coords_table = torch.sign(relative_coords_table) * torch.log( + 1.0 + relative_coords_table.abs()) + return relative_coords_table @@ -115,19 +141,29 @@ class RelPosMlp(nn.Module): window_size, num_heads=8, hidden_dim=128, - class_token=False, + prefix_tokens=0, mode='cr', pretrained_window_size=(0, 0) ): super().__init__() self.window_size = window_size self.window_area = self.window_size[0] * self.window_size[1] - self.class_token = 1 if class_token else 0 + self.prefix_tokens = prefix_tokens self.num_heads = num_heads self.bias_shape = (self.window_area,) * 2 + (num_heads,) - self.apply_sigmoid = mode == 'swin' + if mode == 'swin': + self.bias_act = nn.Sigmoid() + self.bias_gain = 16 + mlp_bias = (True, False) + elif mode == 'rw': + self.bias_act = nn.Tanh() + self.bias_gain = 4 + mlp_bias = True + else: + self.bias_act = nn.Identity() + self.bias_gain = None + mlp_bias = True - mlp_bias = (True, False) if mode == 'swin' else True self.mlp = Mlp( 2, # x, y hidden_features=hidden_dim, @@ -155,10 +191,11 @@ class RelPosMlp(nn.Module): self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.view(self.bias_shape) relative_position_bias = relative_position_bias.permute(2, 0, 1) - if self.apply_sigmoid: - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - if self.class_token: - relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0]) + relative_position_bias = self.bias_act(relative_position_bias) + if self.bias_gain is not None: + relative_position_bias = self.bias_gain * relative_position_bias + if self.prefix_tokens: + relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0]) return relative_position_bias.unsqueeze(0).contiguous() def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): @@ -167,18 +204,18 @@ class RelPosMlp(nn.Module): class RelPosBias(nn.Module): - def __init__(self, window_size, num_heads, class_token=False): + def __init__(self, window_size, num_heads, prefix_tokens=0): super().__init__() + assert prefix_tokens <= 1 self.window_size = window_size self.window_area = window_size[0] * window_size[1] - self.class_token = 1 if class_token else 0 - self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,) + self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,) - num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) self.register_buffer( "relative_position_index", - gen_relative_position_index(self.window_size, class_token=self.class_token), + gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0), persistent=False, ) @@ -306,11 +343,32 @@ class VisionTransformerRelPos(nn.Module): """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6, - class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', - embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='avg', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=1e-6, + class_token=False, + fc_norm=False, + rel_pos_type='mlp', + rel_pos_dim=None, + shared_rel_pos=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='skip', + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + block_fn=RelPosBlock + ): """ Args: img_size (int, tuple): input image size @@ -345,19 +403,22 @@ class VisionTransformerRelPos(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 if class_token else 0 + self.num_prefix_tokens = 1 if class_token else 0 self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) feat_size = self.patch_embed.grid_size - rel_pos_args = dict(window_size=feat_size, class_token=class_token) + rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens) if rel_pos_type.startswith('mlp'): if rel_pos_dim: rel_pos_args['hidden_dim'] = rel_pos_dim + # FIXME experimenting with different relpos log coord configs if 'swin' in rel_pos_type: rel_pos_args['mode'] = 'swin' + elif 'rw' in rel_pos_type: + rel_pos_args['mode'] = 'rw' rel_pos_cls = partial(RelPosMlp, **rel_pos_args) else: rel_pos_cls = partial(RelPosBias, **rel_pos_args) @@ -367,7 +428,7 @@ class VisionTransformerRelPos(nn.Module): # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both... rel_pos_cls = None - self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None + self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ @@ -434,7 +495,7 @@ class VisionTransformerRelPos(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: - x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) return x if pre_logits else self.head(x) @@ -502,6 +563,41 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs): return model +@register_model +def vit_srelpos_small_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ shared relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=False, + rel_pos_dim=384, shared_rel_pos=True, **kwargs) + model = _create_vision_transformer_relpos('vit_srelpos_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_srelpos_medium_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ shared relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False, + rel_pos_dim=512, shared_rel_pos=True, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_srelpos_medium_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_medium_patch16_cls_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-M/16) w/ relative log-coord position, class token present + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False, + rel_pos_dim=256, class_token=True, global_pool='token', **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_medium_patch16_cls_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present @@ -514,14 +610,14 @@ def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs): @register_model -def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs): +def vit_relpos_base_patch16_clsgap_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled Leaving here for comparisons w/ a future re-train as it performs quite well. """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs) - model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_clsgap_224', pretrained=pretrained, **model_kwargs) return model