From ca991c1fa57373286b9876aa63370fd19f5d6032 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 7 Jun 2022 18:01:52 -0700 Subject: [PATCH 01/45] add --aot-autograd --- benchmark.py | 18 +++++++++++++++--- train.py | 15 ++++++++++++++- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/benchmark.py b/benchmark.py index 422da45d..e2370dcc 100755 --- a/benchmark.py +++ b/benchmark.py @@ -51,6 +51,12 @@ except ImportError as e: FlopCountAnalysis = None has_fvcore_profiling = False +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -95,10 +101,13 @@ parser.add_argument('--amp', action='store_true', default=False, help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.') parser.add_argument('--precision', default='float32', type=str, help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)') -parser.add_argument('--torchscript', dest='torchscript', action='store_true', - help='convert model torchscript for inference') parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +scripting_group = parser.add_mutually_exclusive_group() +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") # train optimizer parameters @@ -188,7 +197,7 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False class BenchmarkRunner: def __init__( - self, model_name, detail=False, device='cuda', torchscript=False, precision='float32', + self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32', fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): self.model_name = model_name self.detail = detail @@ -220,6 +229,9 @@ class BenchmarkRunner: if torchscript: self.model = torch.jit.script(self.model) self.scripted = True + if aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + self.model = memory_efficient_fusion(self.model) data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] diff --git a/train.py b/train.py index acdf93c3..c95ec150 100755 --- a/train.py +++ b/train.py @@ -61,6 +61,13 @@ try: except ImportError: has_wandb = False +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('train') @@ -123,8 +130,11 @@ group.add_argument('-vb', '--validation-batch-size', type=int, default=None, met help='Validation batch size override (default: None)') group.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') -group.add_argument('--torchscript', dest='torchscript', action='store_true', +scripting_group = group.add_mutually_exclusive_group() +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', help='torch.jit.script the full model') +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") group.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") group.add_argument('--grad-checkpointing', action='store_true', default=False, @@ -445,6 +455,9 @@ def main(): assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) + if args.aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + model = memory_efficient_fusion(model) optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) From 2d7ab065030462f151f09ef91f86d3f0f4e6bc62 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 9 Jun 2022 14:30:21 -0700 Subject: [PATCH 02/45] Move aot-autograd opt after model metadata used to setup data config in benchmark.py --- benchmark.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmark.py b/benchmark.py index e2370dcc..f348fcb9 100755 --- a/benchmark.py +++ b/benchmark.py @@ -229,14 +229,14 @@ class BenchmarkRunner: if torchscript: self.model = torch.jit.script(self.model) self.scripted = True - if aot_autograd: - assert has_functorch, "functorch is needed for --aot-autograd" - self.model = memory_efficient_fusion(self.model) - data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] self.batch_size = kwargs.pop('batch_size', 256) + if aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + self.model = memory_efficient_fusion(self.model) + self.example_inputs = None self.num_warm_iter = num_warm_iter self.num_bench_iter = num_bench_iter From db64393c0d9908645fcb661769485d0e8ac9b2c5 Mon Sep 17 00:00:00 2001 From: Jakub Kaczmarzyk Date: Mon, 13 Jun 2022 01:30:57 -0400 Subject: [PATCH 03/45] use `Image.Resampling` namespace for PIL mapping (#1256) * use `Image.Resampling` namespace for PIL mapping PIL shows a deprecation warning when accessing resampling constants via the `Image` namespace. The suggested namespace is `Image.Resampling`. This commit updates `_pil_interpolation_to_str` to use the `Image.Resampling` namespace. ``` /tmp/ipykernel_11959/698124036.py:2: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead. Image.NEAREST: 'nearest', /tmp/ipykernel_11959/698124036.py:3: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead. Image.BILINEAR: 'bilinear', /tmp/ipykernel_11959/698124036.py:4: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead. Image.BICUBIC: 'bicubic', /tmp/ipykernel_11959/698124036.py:5: DeprecationWarning: BOX is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BOX instead. Image.BOX: 'box', /tmp/ipykernel_11959/698124036.py:6: DeprecationWarning: HAMMING is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.HAMMING instead. Image.HAMMING: 'hamming', /tmp/ipykernel_11959/698124036.py:7: DeprecationWarning: LANCZOS is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.LANCZOS instead. Image.LANCZOS: 'lanczos', ``` * use new pillow resampling enum only if it exists --- timm/data/transforms.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 45c078f3..3eb3bc32 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -35,14 +35,28 @@ class ToTensor: return torch.from_numpy(np_img).to(dtype=self.dtype) -_pil_interpolation_to_str = { - Image.NEAREST: 'nearest', - Image.BILINEAR: 'bilinear', - Image.BICUBIC: 'bicubic', - Image.BOX: 'box', - Image.HAMMING: 'hamming', - Image.LANCZOS: 'lanczos', -} +# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in +# favor of the Image.Resampling enum. The top-level resampling attributes will be +# removed in Pillow 10. +if hasattr(Image, "Resampling"): + _pil_interpolation_to_str = { + Image.Resampling.NEAREST: 'nearest', + Image.Resampling.BILINEAR: 'bilinear', + Image.Resampling.BICUBIC: 'bicubic', + Image.Resampling.BOX: 'box', + Image.Resampling.HAMMING: 'hamming', + Image.Resampling.LANCZOS: 'lanczos', + } +else: + _pil_interpolation_to_str = { + Image.NEAREST: 'nearest', + Image.BILINEAR: 'bilinear', + Image.BICUBIC: 'bicubic', + Image.BOX: 'box', + Image.HAMMING: 'hamming', + Image.LANCZOS: 'lanczos', + } + _str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()} @@ -181,5 +195,3 @@ class RandomResizedCropAndInterpolation: format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) format_string += ', interpolation={0})'.format(interpolate_str) return format_string - - From 9e12530433f38c536e9a5fdfe5c9f455638a3e8a Mon Sep 17 00:00:00 2001 From: Jakub Kaczmarzyk Date: Thu, 26 May 2022 08:57:47 -0400 Subject: [PATCH 04/45] use utils namespace instead of function/classnames This fixes buggy behavior introduced by https://github.com/rwightman/pytorch-image-models/pull/1266. Related to https://github.com/rwightman/pytorch-image-models/pull/1273. --- train.py | 48 +++++++++++++++++++++++------------------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/train.py b/train.py index c95ec150..047a8256 100755 --- a/train.py +++ b/train.py @@ -31,9 +31,7 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ convert_splitbn_model, model_parameters -from timm.utils import setup_default_logging, random_seed, set_jit_fuser, ModelEmaV2,\ - get_outdir, CheckpointSaver, distribute_bn, update_summary, accuracy, AverageMeter,\ - dispatch_clip_grad, reduce_tensor +from timm import utils from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\ LabelSmoothingCrossEntropy from timm.optim import create_optimizer_v2, optimizer_kwargs @@ -346,7 +344,7 @@ def _parse_args(): def main(): - setup_default_logging() + utils.setup_default_logging() args, args_text = _parse_args() if args.log_wandb: @@ -391,10 +389,10 @@ def main(): _logger.warning("Neither APEX or native Torch AMP is available, using float32. " "Install NVIDA apex or upgrade to PyTorch 1.6") - random_seed(args.seed, args.rank) + utils.random_seed(args.seed, args.rank) if args.fuser: - set_jit_fuser(args.fuser) + utils.set_jit_fuser(args.fuser) model = create_model( args.model, @@ -492,7 +490,7 @@ def main(): model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper - model_ema = ModelEmaV2( + model_ema = utils.ModelEmaV2( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) @@ -640,9 +638,9 @@ def main(): safe_model_name(args.model), str(data_config['input_size'][-1]) ]) - output_dir = get_outdir(args.output if args.output else './output/train', exp_name) + output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False - saver = CheckpointSaver( + saver = utils.CheckpointSaver( model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: @@ -661,13 +659,13 @@ def main(): if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info("Distributing BatchNorm running means and vars") - distribute_bn(model, args.world_size, args.dist_bn == 'reduce') + utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): - distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') + utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics @@ -677,7 +675,7 @@ def main(): lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) if output_dir is not None: - update_summary( + utils.update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) @@ -704,9 +702,9 @@ def train_one_epoch( mixup_fn.mixup_enabled = False second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order - batch_time_m = AverageMeter() - data_time_m = AverageMeter() - losses_m = AverageMeter() + batch_time_m = utils.AverageMeter() + data_time_m = utils.AverageMeter() + losses_m = utils.AverageMeter() model.train() @@ -740,7 +738,7 @@ def train_one_epoch( else: loss.backward(create_graph=second_order) if args.clip_grad is not None: - dispatch_clip_grad( + utils.dispatch_clip_grad( model_parameters(model, exclude_head='agc' in args.clip_mode), value=args.clip_grad, mode=args.clip_mode) optimizer.step() @@ -756,7 +754,7 @@ def train_one_epoch( lr = sum(lrl) / len(lrl) if args.distributed: - reduced_loss = reduce_tensor(loss.data, args.world_size) + reduced_loss = utils.reduce_tensor(loss.data, args.world_size) losses_m.update(reduced_loss.item(), input.size(0)) if args.local_rank == 0: @@ -801,10 +799,10 @@ def train_one_epoch( def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): - batch_time_m = AverageMeter() - losses_m = AverageMeter() - top1_m = AverageMeter() - top5_m = AverageMeter() + batch_time_m = utils.AverageMeter() + losses_m = utils.AverageMeter() + top1_m = utils.AverageMeter() + top5_m = utils.AverageMeter() model.eval() @@ -831,12 +829,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='') target = target[0:target.size(0):reduce_factor] loss = loss_fn(output, target) - acc1, acc5 = accuracy(output, target, topk=(1, 5)) + acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) if args.distributed: - reduced_loss = reduce_tensor(loss.data, args.world_size) - acc1 = reduce_tensor(acc1, args.world_size) - acc5 = reduce_tensor(acc5, args.world_size) + reduced_loss = utils.reduce_tensor(loss.data, args.world_size) + acc1 = utils.reduce_tensor(acc1, args.world_size) + acc5 = utils.reduce_tensor(acc5, args.world_size) else: reduced_loss = loss.data From 037e5e6c09d89df2a31b953fcdfc3d4201f45a54 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 21 Jun 2022 12:31:48 -0700 Subject: [PATCH 05/45] Fix #1309, move wandb init after distributed init, only init on rank == 0 process --- train.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 047a8256..444be066 100755 --- a/train.py +++ b/train.py @@ -347,13 +347,6 @@ def main(): utils.setup_default_logging() args, args_text = _parse_args() - if args.log_wandb: - if has_wandb: - wandb.init(project=args.experiment, config=args) - else: - _logger.warning("You've requested to log metrics to wandb but package not found. " - "Metrics not being logged to wandb, try `pip install wandb`") - args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: @@ -373,6 +366,13 @@ def main(): _logger.info('Training with a single process on 1 GPUs.') assert args.rank >= 0 + if args.rank == 0 and args.log_wandb: + if has_wandb: + wandb.init(project=args.experiment, config=args) + else: + _logger.warning("You've requested to log metrics to wandb but package not found. " + "Metrics not being logged to wandb, try `pip install wandb`") + # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: From 7cedc8d4743f2b2bbf835fc387c917461fa4911a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 21 Jun 2022 14:56:53 -0700 Subject: [PATCH 06/45] Follow up to #1256, fix interpolation warning in auto_autoaugment as well --- timm/data/auto_augment.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 121a3fc6..1b51ccb4 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -36,11 +36,16 @@ _HPARAMS_DEFAULT = dict( img_mean=_FILL, ) -_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +if hasattr(Image, "Resampling"): + _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC) + _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC +else: + _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + _DEFAULT_INTERPOLATION = Image.BICUBIC def _interpolation(kwargs): - interpolation = kwargs.pop('resample', Image.BILINEAR) + interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION) if isinstance(interpolation, (list, tuple)): return random.choice(interpolation) else: From 879df47c0a7e8108545a1ca1fcbfa88ca2714778 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:51:26 -0700 Subject: [PATCH 07/45] Support BatchNormAct2d for sync-bn use. Fix #1254 --- timm/models/__init__.py | 2 +- timm/models/layers/__init__.py | 2 +- timm/models/layers/norm_act.py | 90 ++++++++++++++++++++++++++++++++-- train.py | 19 +++---- 4 files changed, 97 insertions(+), 16 deletions(-) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 8cb6c70a..4f81683a 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -61,7 +61,7 @@ from .xcit import * from .factory import create_model, parse_model_name, safe_model_name from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool -from .layers import convert_splitbn_model +from .layers import convert_splitbn_model, convert_sync_batchnorm from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 7e9e7b19..b1a64db3 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -26,7 +26,7 @@ from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, LayerNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct +from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 34c4fd64..261bdb0a 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -1,10 +1,15 @@ """ Normalization + Activation Layers """ -from typing import Union, List +from typing import Union, List, Optional, Any import torch from torch import nn as nn from torch.nn import functional as F +try: + from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm + FULL_SYNC_BN = True +except ImportError: + FULL_SYNC_BN = False from .trace_utils import _assert from .create_act import get_act_layer @@ -18,10 +23,29 @@ class BatchNormAct2d(nn.BatchNorm2d): instead of composing it as a .bn member. """ def __init__( - self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): - super(BatchNormAct2d, self).__init__( - num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + device=None, + dtype=None + ): + try: + factory_kwargs = {'device': device, 'dtype': dtype} + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats, + **factory_kwargs + ) + except TypeError: + # NOTE for backwards compat with old PyTorch w/o factory device/dtype support + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) self.drop = drop_layer() if drop_layer is not None else nn.Identity() act_layer = get_act_layer(act_layer) # string -> nn.Module if act_layer is not None and apply_act: @@ -81,6 +105,62 @@ class BatchNormAct2d(nn.BatchNorm2d): return x +class SyncBatchNormAct(nn.SyncBatchNorm): + # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254) + # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers + # but ONLY when used in conjunction with the timm conversion function below. + # Do not create this module directly or use the PyTorch conversion function. + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = super().forward(x) # SyncBN doesn't work with torchscript anyways, so this is fine + if hasattr(self, "drop"): + x = self.drop(x) + if hasattr(self, "act"): + x = self.act(x) + return x + + +def convert_sync_batchnorm(module, process_group=None): + # convert both BatchNorm and BatchNormAct layers to Synchronized variants + module_output = module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + if isinstance(module, BatchNormAct2d): + # convert timm norm + act layer + module_output = SyncBatchNormAct( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group=process_group, + ) + # set act and drop attr from the original module + module_output.act = module.act + module_output.drop = module.drop + else: + # convert standard BatchNorm layers + module_output = torch.nn.SyncBatchNorm( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group, + ) + if module.affine: + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + if hasattr(module, "qconfig"): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module(name, convert_sync_batchnorm(child, process_group)) + del module + return module_output + + def _num_groups(num_channels, num_groups, group_size): if group_size: assert num_channels % group_size == 0 diff --git a/train.py b/train.py index 444be066..2a68e05e 100755 --- a/train.py +++ b/train.py @@ -15,10 +15,9 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ import argparse -import time -import yaml -import os import logging +import os +import time from collections import OrderedDict from contextlib import suppress from datetime import datetime @@ -26,14 +25,15 @@ from datetime import datetime import torch import torch.nn as nn import torchvision.utils +import yaml from torch.nn.parallel import DistributedDataParallel as NativeDDP -from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ - convert_splitbn_model, model_parameters from timm import utils -from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\ +from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset +from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \ LabelSmoothingCrossEntropy +from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ + convert_splitbn_model, convert_sync_batchnorm, model_parameters from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler @@ -440,10 +440,11 @@ def main(): if args.distributed and args.sync_bn: assert not args.split_bn if has_apex and use_amp == 'apex': - # Apex SyncBN preferred unless native amp is activated + # Apex SyncBN used with Apex AMP + # WARNING this won't currently work with models using BatchNormAct2d model = convert_syncbn_model(model) else: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' From 7d657d2ef45fc841f3a987ca0f18868686dbecf5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:55:25 -0700 Subject: [PATCH 08/45] Improve resolve_pretrained_cfg behaviour when no cfg exists, warn instead of crash. Improve usability ex #1311 --- timm/models/helpers.py | 27 ++++++++++++++++-------- timm/models/inception_v3.py | 2 +- timm/models/vision_transformer.py | 2 +- timm/models/vision_transformer_relpos.py | 2 +- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 1276b68e..11630bb6 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -455,18 +455,27 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter): filter_kwargs(kwargs, names=kwargs_filter) -def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None): +def resolve_pretrained_cfg(variant: str, **kwargs): + pretrained_cfg = kwargs.pop('pretrained_cfg', None) if pretrained_cfg and isinstance(pretrained_cfg, dict): - # highest priority, pretrained_cfg available and passed explicitly + # highest priority, pretrained_cfg available and passed in args return deepcopy(pretrained_cfg) - if kwargs and 'pretrained_cfg' in kwargs: - # next highest, pretrained_cfg in a kwargs dict, pop and return - pretrained_cfg = kwargs.pop('pretrained_cfg', {}) - if pretrained_cfg: - return deepcopy(pretrained_cfg) - # lookup pretrained cfg in model registry by variant + # fallback to looking up pretrained cfg in model registry by variant identifier pretrained_cfg = get_pretrained_cfg(variant) - assert pretrained_cfg + if not pretrained_cfg: + _logger.warning( + f"No pretrained configuration specified for {variant} model. Using a default." + f" Please add a config to the model pretrained_cfg registry or pass explicitly.") + pretrained_cfg = dict( + url='', + num_classes=1000, + input_size=(3, 224, 224), + pool_size=None, + crop_pct=.9, + interpolation='bicubic', + first_conv='', + classifier='', + ) return pretrained_cfg diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index e34de657..2c6e7eb7 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3): def _create_inception_v3(variant, pretrained=False, **kwargs): - pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) aux_logits = kwargs.pop('aux_logits', False) if aux_logits: assert not kwargs.pop('features_only', False) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 59fd7849..8551feae 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) model = build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_cfg=pretrained_cfg, diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 0c2ac376..0c9ac989 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply +from .helpers import build_model_with_cfg, named_apply from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple from .registry import register_model From 0da3c9ebbf483f76f51a366bebfd3d3589f74a0c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:56:58 -0700 Subject: [PATCH 09/45] Remove SiLU layer in default args that breaks import on old old PyTorch --- timm/models/layers/evo_norm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index b643302c..ea776207 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -256,8 +256,9 @@ class EvoNorm2dS0a(EvoNorm2dS0): class EvoNorm2dS1(nn.Module): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + apply_act=True, act_layer=None, eps=1e-5, **_): super().__init__() + act_layer = act_layer or nn.SiLU self.apply_act = apply_act # apply activation (non-linearity) if act_layer is not None and apply_act: self.act = create_act_layer(act_layer) @@ -290,7 +291,7 @@ class EvoNorm2dS1(nn.Module): class EvoNorm2dS1a(EvoNorm2dS1): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): + apply_act=True, act_layer=None, eps=1e-3, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) @@ -305,8 +306,9 @@ class EvoNorm2dS1a(EvoNorm2dS1): class EvoNorm2dS2(nn.Module): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + apply_act=True, act_layer=None, eps=1e-5, **_): super().__init__() + act_layer = act_layer or nn.SiLU self.apply_act = apply_act # apply activation (non-linearity) if act_layer is not None and apply_act: self.act = create_act_layer(act_layer) @@ -338,7 +340,7 @@ class EvoNorm2dS2(nn.Module): class EvoNorm2dS2a(EvoNorm2dS2): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): + apply_act=True, act_layer=None, eps=1e-3, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) From e27c16b8a02eaeae2d42d969f4b137e8e940dbe1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:57:42 -0700 Subject: [PATCH 10/45] Remove unecessary code for synbn guard --- timm/models/layers/norm_act.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 261bdb0a..ea5b7883 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -5,11 +5,6 @@ from typing import Union, List, Optional, Any import torch from torch import nn as nn from torch.nn import functional as F -try: - from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm - FULL_SYNC_BN = True -except ImportError: - FULL_SYNC_BN = False from .trace_utils import _assert from .create_act import get_act_layer From 07d0c4ae963481a225ce8243f047408d96f13fab Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:58:15 -0700 Subject: [PATCH 11/45] Improve repr for DropPath module --- timm/models/layers/drop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index ae065277..1ab1c8f5 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -164,3 +164,6 @@ class DropPath(nn.Module): def forward(self, x): return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' From a29fba307dc544ac1d4eeab0043c3aae8d5aa7c8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 21:30:17 -0700 Subject: [PATCH 12/45] disable dist_bn when sync_bn active --- train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train.py b/train.py index 2a68e05e..285981fd 100755 --- a/train.py +++ b/train.py @@ -438,6 +438,7 @@ def main(): # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: + args.dist_bn = '' # disable dist_bn when sync BN active assert not args.split_bn if has_apex and use_amp == 'apex': # Apex SyncBN used with Apex AMP From e6d7df40ecb9595b9e4f2f5b69f2633f81cee9ce Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 21:32:44 -0700 Subject: [PATCH 13/45] no longer a point using kwargs for pretrain_cfg resolve, just pass explicit arg --- timm/models/helpers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 11630bb6..fda84171 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -455,10 +455,9 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter): filter_kwargs(kwargs, names=kwargs_filter) -def resolve_pretrained_cfg(variant: str, **kwargs): - pretrained_cfg = kwargs.pop('pretrained_cfg', None) +def resolve_pretrained_cfg(variant: str, pretrained_cfg=None): if pretrained_cfg and isinstance(pretrained_cfg, dict): - # highest priority, pretrained_cfg available and passed in args + # highest priority, pretrained_cfg available and passed as arg return deepcopy(pretrained_cfg) # fallback to looking up pretrained cfg in model registry by variant identifier pretrained_cfg = get_pretrained_cfg(variant) From 34f382f8f6583a80cb0a169c275bf0806d95ca06 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 14:50:36 -0700 Subject: [PATCH 14/45] move dataconfig before script, scripting killing metadata now (PyTorch 1.12? just nvfuser?) --- benchmark.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmark.py b/benchmark.py index f348fcb9..1362eeab 100755 --- a/benchmark.py +++ b/benchmark.py @@ -225,11 +225,12 @@ class BenchmarkRunner: self.num_classes = self.model.num_classes self.param_count = count_params(self.model) _logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) + + data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.scripted = False if torchscript: self.model = torch.jit.script(self.model) self.scripted = True - data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] self.batch_size = kwargs.pop('batch_size', 256) From a050fde5cde892404a5b77973a5916cdd7b602ab Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 15:03:28 -0700 Subject: [PATCH 15/45] Add resnet10t (basic block) and resnet14t (bottleneck) with 1,1,1,1 repeats --- timm/models/resnet.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index a7f0c0f6..476ffe91 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -723,6 +723,24 @@ def _create_resnet(variant, pretrained=False, **kwargs): return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) +@register_model +def resnet10t(pretrained=False, **kwargs): + """Constructs a ResNet-10-T model. + """ + model_args = dict( + block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet10t', pretrained, **model_args) + + +@register_model +def resnet14t(pretrained=False, **kwargs): + """Constructs a ResNet-14-T model. + """ + model_args = dict( + block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet14t', pretrained, **model_args) + + @register_model def resnet18(pretrained=False, **kwargs): """Constructs a ResNet-18 model. From 82c311d0821643e8613b0b18f6d0f14088a79459 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 15:14:01 -0700 Subject: [PATCH 16/45] Add more experimental darknet and 'cs2' darknet variants (different cross stage setup, closer to newer YOLO backbones) for train trials. --- timm/models/cspnet.py | 384 ++++++++++++++++++++++++++---- timm/models/layers/conv_bn_act.py | 19 +- 2 files changed, 352 insertions(+), 51 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index f8a87fab..095e4701 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -16,6 +16,7 @@ from functools import partial import torch import torch.nn as nn +import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP @@ -46,11 +47,21 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth', input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.875 # FIXME I trained this at 224x224, not 256 like ref impl ), - 'cspresnext50_iabn': _cfg(url=''), 'cspdarknet53': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'), - 'cspdarknet53_iabn': _cfg(url=''), + + 'darknet17': _cfg(url=''), + 'darknet21': _cfg(url=''), 'darknet53': _cfg(url=''), + + 'cs2darknet_m': _cfg( + url=''), + 'cs2darknet_l': _cfg( + url=''), + 'cs2darknet_f_m': _cfg( + url=''), + 'cs2darknet_f_l': _cfg( + url=''), } @@ -116,6 +127,37 @@ model_cfgs = dict( down_growth=True, ) ), + darknet17=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1,) * 5, + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + ) + ), + darknet21=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 1, 1, 2, 2), + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + ) + ), + sedarknet21=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 1, 1, 2, 2), + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + attn_layer=('se',) * 5, + ) + ), darknet53=dict( stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), stage=dict( @@ -125,13 +167,81 @@ model_cfgs = dict( bottle_ratio=(0.5,) * 5, block_ratio=(1.,) * 5, ) + ), + + darknetaa53=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 2, 8, 8, 4), + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + avg_down=True, + ), + ), + + cs2darknet_m=dict( + stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''), + stage=dict( + out_chs=(96, 192, 384, 768), + depth=(2, 4, 6, 2), + stride=(2,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + avg_down=False, + ), + ), + + cs2darknet_f_m=dict( + stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), + stage=dict( + out_chs=(96, 192, 384, 768), + depth=(2, 4, 6, 2), + stride=(2,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + avg_down=False, + ), + ), + + cs2darknet_l=dict( + stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), + stage=dict( + out_chs=(128, 256, 512, 1024), + depth=(3, 6, 9, 3), + stride=(2,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + avg_down=False, + ), + ), + + cs2darknet_f_l=dict( + stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''), + stage=dict( + out_chs=(128, 256, 512, 1024), + depth=(3, 6, 9, 3), + stride=(2,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + avg_down=False, + ), ) ) def create_stem( - in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='', - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): + in_chans=3, + out_chs=32, + kernel_size=3, + stride=2, + pool='', + padding='', + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + aa_layer=None +): stem = nn.Sequential() if not isinstance(out_chs, (tuple, list)): out_chs = [out_chs] @@ -140,8 +250,12 @@ def create_stem( for i, out_c in enumerate(out_chs): conv_name = f'conv{i + 1}' stem.add_module(conv_name, ConvNormAct( - in_c, out_c, kernel_size, stride=stride if i == 0 else 1, - act_layer=act_layer, norm_layer=norm_layer)) + in_c, out_c, kernel_size, + stride=stride if i == 0 else 1, + padding=padding if i == 0 else '', + act_layer=act_layer, + norm_layer=norm_layer + )) in_c = out_c last_conv = conv_name if pool: @@ -158,9 +272,20 @@ class ResBottleneck(nn.Module): """ def __init__( - self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False, - attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + self, + in_chs, + out_chs, + dilation=1, + bottle_ratio=0.25, + groups=1, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + attn_last=False, + attn_layer=None, + aa_layer=None, + drop_block=None, + drop_path=None + ): super(ResBottleneck, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) @@ -173,7 +298,7 @@ class ResBottleneck(nn.Module): self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None self.drop_path = drop_path - self.act3 = act_layer(inplace=True) + self.act3 = act_layer() def zero_init_last(self): nn.init.zeros_(self.conv3.bn.weight) @@ -201,9 +326,19 @@ class DarkBlock(nn.Module): """ def __init__( - self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, - drop_block=None, drop_path=None): + self, + in_chs, + out_chs, + dilation=1, + bottle_ratio=0.5, + groups=1, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + attn_layer=None, + aa_layer=None, + drop_block=None, + drop_path=None + ): super(DarkBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) @@ -211,7 +346,7 @@ class DarkBlock(nn.Module): self.conv2 = ConvNormActAa( mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, aa_layer=aa_layer, drop_layer=drop_block, **ckwargs) - self.attn = create_attn(attn_layer, channels=out_chs) + self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer) self.drop_path = drop_path def zero_init_last(self): @@ -232,23 +367,44 @@ class DarkBlock(nn.Module): class CrossStage(nn.Module): """Cross Stage.""" def __init__( - self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1., - groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None, - block_fn=ResBottleneck, **block_kwargs): + self, + in_chs, + out_chs, + stride, + dilation, + depth, + block_ratio=1., + bottle_ratio=1., + exp_ratio=1., + groups=1, + first_dilation=None, + avg_down=False, + down_growth=False, + cross_linear=False, + block_dpr=None, + block_fn=ResBottleneck, + **block_kwargs + ): super(CrossStage, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels - exp_chs = int(round(out_chs * exp_ratio)) + self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) block_out_chs = int(round(out_chs * block_ratio)) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: - self.conv_down = ConvNormActAa( - in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, - aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) + if avg_down: + self.conv_down = nn.Sequential( + nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ) + else: + self.conv_down = ConvNormActAa( + in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) prev_chs = down_chs else: - self.conv_down = None + self.conv_down = nn.Identity() prev_chs = in_chs # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also, @@ -269,30 +425,115 @@ class CrossStage(nn.Module): self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) def forward(self, x): - if self.conv_down is not None: - x = self.conv_down(x) + x = self.conv_down(x) x = self.conv_exp(x) - split = x.shape[1] // 2 - xs, xb = x[:, :split], x[:, split:] + xs, xb = x.split(self.exp_chs // 2, dim=1) xb = self.blocks(xb) xb = self.conv_transition_b(xb).contiguous() out = self.conv_transition(torch.cat([xs, xb], dim=1)) return out +class CrossStage2(nn.Module): + """Cross Stage v2. + Similar to CrossStage, but with one transition conv for the concat output. + """ + def __init__( + self, + in_chs, + out_chs, + stride, + dilation, + depth, + block_ratio=1., + bottle_ratio=1., + exp_ratio=1., + groups=1, + first_dilation=None, + avg_down=False, + down_growth=False, + cross_linear=False, + block_dpr=None, + block_fn=ResBottleneck, + **block_kwargs + ): + super(CrossStage2, self).__init__() + first_dilation = first_dilation or dilation + down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels + self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) + block_out_chs = int(round(out_chs * block_ratio)) + conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) + + if stride != 1 or first_dilation != dilation: + if avg_down: + self.conv_down = nn.Sequential( + nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ) + else: + self.conv_down = ConvNormActAa( + in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) + prev_chs = down_chs + else: + self.conv_down = None + prev_chs = in_chs + + # expansion conv + self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs) + prev_chs = exp_chs // 2 # expanded output is split in 2 for blocks and cross stage + + self.blocks = nn.Sequential() + for i in range(depth): + drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None + self.blocks.add_module(str(i), block_fn( + prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + prev_chs = block_out_chs + + # transition convs + self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) + + def forward(self, x): + x = self.conv_down(x) + x = self.conv_exp(x) + x1, x2 = x.split(self.exp_chs // 2, dim=1) + x1 = self.blocks(x1) + out = self.conv_transition(torch.cat([x1, x2], dim=1)) + return out + + class DarkStage(nn.Module): """DarkNet stage.""" def __init__( - self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1, - first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs): + self, + in_chs, + out_chs, + stride, + dilation, + depth, + block_ratio=1., + bottle_ratio=1., + groups=1, + first_dilation=None, + avg_down=False, + block_fn=ResBottleneck, + block_dpr=None, + **block_kwargs + ): super(DarkStage, self).__init__() first_dilation = first_dilation or dilation + conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) - self.conv_down = ConvNormActAa( - in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, - act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'), - aa_layer=block_kwargs.get('aa_layer', None)) + if avg_down: + self.conv_down = nn.Sequential( + nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ) + else: + self.conv_down = ConvNormActAa( + in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) prev_chs = out_chs block_out_chs = int(round(out_chs * block_ratio)) @@ -318,6 +559,8 @@ def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): cfg['down_growth'] = (cfg['down_growth'],) * num_stages if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages + if 'avg_down' in cfg and not isinstance(cfg['avg_down'], (list, tuple)): + cfg['avg_down'] = (cfg['avg_down'],) * num_stages cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] stage_strides = [] @@ -352,9 +595,20 @@ class CspNet(nn.Module): """ def __init__( - self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., - act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0., - zero_init_last=True, stage_fn=CrossStage, block_fn=ResBottleneck): + self, + cfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + act_layer=nn.LeakyReLU, + norm_layer=nn.BatchNorm2d, + aa_layer=None, + drop_rate=0., + drop_path_rate=0., + zero_init_last=True, + stage_fn=CrossStage, + block_fn=ResBottleneck): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -427,23 +681,22 @@ class CspNet(nn.Module): def _init_weights(module, name, zero_init_last=False): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(module, nn.BatchNorm2d): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) + if module.bias is not None: + nn.init.zeros_(module.bias) elif isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.01) - nn.init.zeros_(module.bias) + if module.bias is not None: + nn.init.zeros_(module.bias) elif zero_init_last and hasattr(module, 'zero_init_last'): module.zero_init_last() def _create_cspnet(variant, pretrained=False, **kwargs): - cfg_variant = variant.split('_')[0] # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4)) return build_model_with_cfg( CspNet, variant, pretrained, - model_cfg=model_cfgs[cfg_variant], + model_cfg=model_cfgs[variant], feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) @@ -469,22 +722,55 @@ def cspresnext50(pretrained=False, **kwargs): @register_model -def cspresnext50_iabn(pretrained=False, **kwargs): - norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu') - return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs) +def cspdarknet53(pretrained=False, **kwargs): + return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) @register_model -def cspdarknet53(pretrained=False, **kwargs): - return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) +def darknet17(pretrained=False, **kwargs): + return _create_cspnet('darknet17', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) @register_model -def cspdarknet53_iabn(pretrained=False, **kwargs): - norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu') - return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs) +def darknet21(pretrained=False, **kwargs): + return _create_cspnet('darknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + + +@register_model +def sedarknet21(pretrained=False, **kwargs): + return _create_cspnet('sedarknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) @register_model def darknet53(pretrained=False, **kwargs): return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + + +@register_model +def darknetaa53(pretrained=False, **kwargs): + return _create_cspnet( + 'darknetaa53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + + +@register_model +def cs2darknet_m(pretrained=False, **kwargs): + return _create_cspnet( + 'cs2darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + + +@register_model +def cs2darknet_l(pretrained=False, **kwargs): + return _create_cspnet( + 'cs2darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + + +@register_model +def cs2darknet_f_m(pretrained=False, **kwargs): + return _create_cspnet( + 'cs2darknet_f_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + + +@register_model +def cs2darknet_f_l(pretrained=False, **kwargs): + return _create_cspnet( + 'cs2darknet_f_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) \ No newline at end of file diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py index af010573..9e7c64b8 100644 --- a/timm/models/layers/conv_bn_act.py +++ b/timm/models/layers/conv_bn_act.py @@ -2,6 +2,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import functools from torch import nn as nn from .create_conv2d import create_conv2d @@ -40,12 +41,26 @@ class ConvNormAct(nn.Module): ConvBnAct = ConvNormAct +def create_aa(aa_layer, channels, stride=2, enable=True): + if not aa_layer or not enable: + return nn.Identity() + if isinstance(aa_layer, functools.partial): + if issubclass(aa_layer.func, nn.AvgPool2d): + return aa_layer() + else: + return aa_layer(channels) + elif issubclass(aa_layer, nn.AvgPool2d): + return aa_layer(stride) + else: + return aa_layer(channels=channels, stride=stride) + + class ConvNormActAa(nn.Module): def __init__( self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): super(ConvNormActAa, self).__init__() - use_aa = aa_layer is not None + use_aa = aa_layer is not None and stride == 2 self.conv = create_conv2d( in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, @@ -56,7 +71,7 @@ class ConvNormActAa(nn.Module): # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) - self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity() + self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) @property def in_channels(self): From 7a9c6811c91123f84af963e5302a9d18c7c33716 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 15:15:39 -0700 Subject: [PATCH 17/45] Add eps arg to LayerNorm2d, add 'tf' (tensorflow) variant of trunc_normal_ that applies scale/shift after sampling (instead of needing to move a/b) --- timm/models/layers/__init__.py | 2 +- timm/models/layers/norm.py | 4 ++-- timm/models/layers/weight_init.py | 36 ++++++++++++++++++++++++++++++- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index b1a64db3..b1f452ff 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -39,4 +39,4 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .trace_utils import _assert, _float_to_int -from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ +from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 85297420..345f67bc 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -16,8 +16,8 @@ class GroupNorm(nn.GroupNorm): class LayerNorm2d(nn.LayerNorm): """ LayerNorm for channels of '2D' spatial BCHW tensors """ - def __init__(self, num_channels): - super().__init__(num_channels) + def __init__(self, num_channels, eps=1e-6): + super().__init__(num_channels, eps=eps) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py index 305a2fd0..4a160931 100644 --- a/timm/models/layers/weight_init.py +++ b/timm/models/layers/weight_init.py @@ -49,6 +49,11 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. + Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution @@ -62,6 +67,35 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): return _no_grad_trunc_normal_(tensor, mean, std, a, b) +def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + _no_grad_trunc_normal_(tensor, 0, 1.0, a, b) + with torch.no_grad(): + tensor.mul_(std).add_(mean) + return tensor + + def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == 'fan_in': @@ -75,7 +109,7 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): if distribution == "truncated_normal": # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978) elif distribution == "normal": tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": From 6064d16a2dfe89b1d3706df338cecfdcee395d1f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 15:16:41 -0700 Subject: [PATCH 18/45] Add initial EdgeNeXt import. Significant cleanup / reorg (like ConvNeXt). Fix #1320 * edgenext refactored for torchscript compat, stage base organization * slight refactor of ConvNeXt to match some EdgeNeXt additions * remove use of funky LayerNorm layer in ConvNeXt and just use nn.LayerNorm and LayerNorm2d (permute) --- timm/models/__init__.py | 1 + timm/models/convnext.py | 190 ++++++++------ timm/models/edgenext.py | 545 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 665 insertions(+), 71 deletions(-) create mode 100644 timm/models/edgenext.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 4f81683a..195e451b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -12,6 +12,7 @@ from .deit import * from .densenet import * from .dla import * from .dpn import * +from .edgenext import * from .efficientnet import * from .ghostnet import * from .gluon_resnet import * diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 1aacef2b..662695c7 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .fx_features import register_notrace_module from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp +from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d from .registry import register_model @@ -44,6 +44,7 @@ default_cfgs = dict( convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_nano_hnf=_cfg(url=''), + convnext_nano_ols=_cfg(url=''), convnext_tiny_hnf=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', crop_pct=0.95), @@ -88,35 +89,6 @@ default_cfgs = dict( ) -def _is_contiguous(tensor: torch.Tensor) -> bool: - # jit is oh so lovely :/ - # if torch.jit.is_tracing(): - # return True - if torch.jit.is_scripting(): - return tensor.is_contiguous() - else: - return tensor.is_contiguous(memory_format=torch.contiguous_format) - - -@register_notrace_module -class LayerNorm2d(nn.LayerNorm): - r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). - """ - - def __init__(self, normalized_shape, eps=1e-6): - super().__init__(normalized_shape, eps=eps) - - def forward(self, x) -> torch.Tensor: - if _is_contiguous(x): - return F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) - else: - s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) - x = (x - u) * torch.rsqrt(s + self.eps) - x = x * self.weight[:, None, None] + self.bias[:, None, None] - return x - - class ConvNeXtBlock(nn.Module): """ ConvNeXt Block There are two equivalent implementations: @@ -133,21 +105,39 @@ class ConvNeXtBlock(nn.Module): ls_init_value (float): Init value for Layer Scale. Default: 1e-6. """ - def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None): + def __init__( + self, + dim, + dim_out=None, + stride=1, + mlp_ratio=4, + conv_mlp=False, + conv_bias=True, + ls_init_value=1e-6, + norm_layer=None, + act_layer=nn.GELU, + drop_path=0., + ): super().__init__() + dim_out = dim_out or dim if not norm_layer: norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp - self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv - self.norm = norm_layer(dim) - self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU) - self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.shortcut_after_dw = stride > 1 + + self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias) + self.norm = norm_layer(dim_out) + self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): shortcut = x x = self.conv_dw(x) + if self.shortcut_after_dw: + shortcut = x + if self.use_conv_mlp: x = self.norm(x) x = self.mlp(x) @@ -158,32 +148,55 @@ class ConvNeXtBlock(nn.Module): x = x.permute(0, 3, 1, 2) if self.gamma is not None: x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + x = self.drop_path(x) + shortcut + #print('b', x.shape) return x class ConvNeXtStage(nn.Module): def __init__( - self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, - norm_layer=None, cl_norm_layer=None, cross_stage=False): + self, + in_chs, + out_chs, + stride=2, + depth=2, + drop_path_rates=None, + ls_init_value=1.0, + downsample_block=False, + conv_mlp=False, + conv_bias=True, + norm_layer=None, + norm_layer_cl=None + ): super().__init__() self.grad_checkpointing = False - if in_chs != out_chs or stride > 1: + if downsample_block or (in_chs == out_chs and stride == 1): + self.downsample = nn.Identity() + else: self.downsample = nn.Sequential( norm_layer(in_chs), - nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride), + nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias), ) - else: - self.downsample = nn.Identity() - - dp_rates = dp_rates or [0.] * depth - self.blocks = nn.Sequential(*[ConvNeXtBlock( - dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, - norm_layer=norm_layer if conv_mlp else cl_norm_layer) - for j in range(depth)] - ) + in_chs = out_chs + + drop_path_rates = drop_path_rates or [0.] * depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append(ConvNeXtBlock( + dim=in_chs, + dim_out=out_chs, + stride=stride if downsample_block and i == 0 else 1, + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + norm_layer=norm_layer if conv_mlp else norm_layer_cl + )) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) def forward(self, x): x = self.downsample(x) @@ -210,41 +223,57 @@ class ConvNeXt(nn.Module): """ def __init__( - self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, - depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch', - head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + output_stride=32, + depths=(3, 3, 9, 3), + dims=(96, 192, 384, 768), + ls_init_value=1e-6, + stem_type='patch', + stem_kernel_size=4, + stem_stride=4, + head_init_scale=1., + head_norm_first=False, + downsample_block=False, + conv_mlp=False, + conv_bias=True, + norm_layer=None, + drop_rate=0., + drop_path_rate=0., ): super().__init__() assert output_stride == 32 if norm_layer is None: norm_layer = partial(LayerNorm2d, eps=1e-6) - cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' - cl_norm_layer = norm_layer + norm_layer_cl = norm_layer self.num_classes = num_classes self.drop_rate = drop_rate self.feature_info = [] - # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + assert stem_type in ('patch', 'overlap') if stem_type == 'patch': + assert stem_kernel_size == stem_stride + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), + nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias), norm_layer(dims[0]) ) - curr_stride = patch_size - prev_chs = dims[0] else: self.stem = nn.Sequential( - nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1), - norm_layer(32), - nn.GELU(), - nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.Conv2d( + in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, + padding=stem_kernel_size // 2, bias=conv_bias), + norm_layer(dims[0]), ) - curr_stride = 2 - prev_chs = 64 + prev_chs = dims[0] + curr_stride = stem_stride self.stages = nn.Sequential() dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] @@ -256,16 +285,24 @@ class ConvNeXt(nn.Module): curr_stride *= stride out_chs = dims[i] stages.append(ConvNeXtStage( - prev_chs, out_chs, stride=stride, - depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, - norm_layer=norm_layer, cl_norm_layer=cl_norm_layer) - ) + prev_chs, + out_chs, + stride=stride, + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + downsample_block=downsample_block, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl + )) prev_chs = out_chs # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.num_features = prev_chs + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() @@ -327,10 +364,11 @@ class ConvNeXt(nn.Module): def _init_weights(module, name=None, head_init_scale=1.0): if isinstance(module, nn.Conv2d): trunc_normal_(module.weight, std=.02) - nn.init.constant_(module.bias, 0) + if module.bias is not None: + nn.init.zeros_(module.bias) elif isinstance(module, nn.Linear): trunc_normal_(module.weight, std=.02) - nn.init.constant_(module.bias, 0) + nn.init.zeros_(module.bias) if name and 'head.' in name: module.weight.data.mul_(head_init_scale) module.bias.data.mul_(head_init_scale) @@ -371,11 +409,21 @@ def _create_convnext(variant, pretrained=False, **kwargs): @register_model def convnext_nano_hnf(pretrained=False, **kwargs): - model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) + model_args = dict( + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args) return model +@register_model +def convnext_nano_ols(pretrained=False, **kwargs): + model_args = dict( + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), downsample_block=True, + conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs) + model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args) + return model + + @register_model def convnext_tiny_hnf(pretrained=False, **kwargs): model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py new file mode 100644 index 00000000..0f8b0464 --- /dev/null +++ b/timm/models/edgenext.py @@ -0,0 +1,545 @@ +""" EdgeNeXt + +Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications` + - https://arxiv.org/abs/2206.10589 + +Original code and weights from https://github.com/mmaaz60/EdgeNeXt + +Modifications and additions for timm by / Copyright 2022, Ross Wightman +""" +import math +import torch +from collections import OrderedDict +from functools import partial +from typing import Tuple + +from torch import nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import trunc_normal_tf_ +from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from .helpers import named_apply, build_model_with_cfg, checkpoint_seq +from .registry import register_model + + +__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + edgenext_xx_small=_cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth"), + edgenext_x_small=_cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth"), + # edgenext_small=_cfg( + # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"), + edgenext_small=_cfg( # USI weights + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", + crop_pct=0.95 + ), + + edgenext_small_rw=_cfg(), +) + + +class PositionalEncodingFourier(nn.Module): + def __init__(self, hidden_dim=32, dim=768, temperature=10000): + super().__init__() + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + + def forward(self, shape: Tuple[int, int, int]): + inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool) + y_embed = inv_mask.cumsum(1, dtype=torch.float32) + x_embed = inv_mask.cumsum(2, dtype=torch.float32) + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), + pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), + pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + + return pos + + +class ConvBlock(nn.Module): + def __init__( + self, + dim, + dim_out=None, + kernel_size=7, + stride=1, + conv_bias=True, + expand_ratio=4, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, drop_path=0., + ): + super().__init__() + dim_out = dim_out or dim + self.shortcut_after_dw = stride > 1 or dim != dim_out + + self.conv_dw = create_conv2d( + dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias) + self.norm = norm_layer(dim_out) + self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + if self.shortcut_after_dw: + shortcut = x + + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class CrossCovarianceAttn(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) + + # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature'} + + +class SplitTransposeBlock(nn.Module): + def __init__( + self, + dim, + num_scales=1, + num_heads=8, + expand_ratio=4, + use_pos_emb=True, + conv_bias=True, + qkv_bias=True, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + drop_path=0., + attn_drop=0., + proj_drop=0. + ): + super().__init__() + width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales))) + self.width = width + self.num_scales = max(1, num_scales - 1) + + convs = [] + for i in range(self.num_scales): + convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias)) + self.convs = nn.ModuleList(convs) + + self.pos_embd = None + if use_pos_emb: + self.pos_embd = PositionalEncodingFourier(dim=dim) + self.norm_xca = norm_layer(dim) + self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.xca = CrossCovarianceAttn( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) + + self.norm = norm_layer(dim, eps=1e-6) + self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + # scales code re-written for torchscript as per my res2net fixes -rw + spx = torch.split(x, self.width, 1) + spo = [] + sp = spx[0] + for i, conv in enumerate(self.convs): + if i > 0: + sp = sp + spx[i] + sp = conv(sp) + spo.append(sp) + spo.append(spx[-1]) + x = torch.cat(spo, 1) + + # XCA + B, C, H, W = x.shape + x = x.reshape(B, C, H * W).permute(0, 2, 1) + if self.pos_embd is not None: + pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1) + x = x + pos_encoding + x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x))) + x = x.reshape(B, H, W, C) + + # Inverted Bottleneck + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class EdgeNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=2, + depth=2, + num_global_blocks=1, + num_heads=4, + scales=2, + kernel_size=7, + expand_ratio=4, + use_pos_emb=False, + downsample_block=False, + conv_bias=True, + ls_init_value=1.0, + drop_path_rates=None, + norm_layer=LayerNorm2d, + norm_layer_cl=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU + ): + super().__init__() + self.grad_checkpointing = False + + if downsample_block or stride == 1: + self.downsample = nn.Identity() + else: + self.downsample = nn.Sequential( + norm_layer(in_chs), + nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias) + ) + in_chs = out_chs + + stage_blocks = [] + for i in range(depth): + if i < depth - num_global_blocks: + stage_blocks.append( + ConvBlock( + dim=in_chs, + dim_out=out_chs, + stride=stride if downsample_block and i == 0 else 1, + conv_bias=conv_bias, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + else: + stage_blocks.append( + SplitTransposeBlock( + dim=in_chs, + num_scales=scales, + num_heads=num_heads, + expand_ratio=expand_ratio, + use_pos_emb=use_pos_emb, + conv_bias=conv_bias, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) + + def forward(self, x): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class EdgeNeXt(nn.Module): + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + dims=(24, 48, 88, 168), + depths=(3, 3, 9, 3), + global_block_counts=(0, 1, 1, 1), + kernel_sizes=(3, 5, 7, 9), + heads=(8, 8, 8, 8), + d2_scales=(2, 2, 3, 4), + use_pos_emb=(False, True, False, False), + ls_init_value=1e-6, + head_init_scale=1., + expand_ratio=4, + downsample_block=False, + conv_bias=True, + stem_type='patch', + head_norm_first=False, + act_layer=nn.GELU, + drop_path_rate=0., + drop_rate=0., + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.drop_rate = drop_rate + norm_layer = partial(LayerNorm2d, eps=1e-6) + norm_layer_cl = partial(nn.LayerNorm, eps=1e-6) + + assert stem_type in ('patch', 'overlap') + if stem_type == 'patch': + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias), + norm_layer(dims[0]), + ) + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias), + norm_layer(dims[0]), + ) + + stages = [] + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + in_chs = dims[0] + for i in range(4): + stages.append(EdgeNeXtStage( + in_chs=in_chs, + out_chs=dims[i], + stride=2 if i > 0 else 1, + depth=depths[i], + num_global_blocks=global_block_counts[i], + num_heads=heads[i], + drop_path_rates=dp_rates[i], + scales=d2_scales[i], + expand_ratio=expand_ratio, + kernel_size=kernel_sizes[i], + use_pos_emb=use_pos_emb[i], + ls_init_value=ls_init_value, + downsample_block=downsample_block, + conv_bias=conv_bias, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + act_layer=act_layer, + )) + in_chs = dims[i] + self.stages = nn.Sequential(*stages) + + self.num_features = dims[-1] + self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) + + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes=0, global_pool=None): + if global_pool is not None: + self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( + x = self.head.global_pool(x) + x = self.head.norm(x) + x = self.head.flatten(x) + x = self.head.drop(x) + return x if pre_logits else self.head.fc(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _init_weights(module, name=None, head_init_scale=1.0): + if isinstance(module, nn.Conv2d): + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_tf_(module.weight, std=.02) + nn.init.zeros_(module.bias) + if name and 'head.' in name: + module.weight.data.mul_(head_init_scale) + module.bias.data.mul_(head_init_scale) + + +def checkpoint_filter_fn(state_dict, model): + """ Remap FB checkpoints -> timm """ + if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: + return state_dict # non-FB checkpoint + + # models were released as train checkpoints... :/ + if 'model_ema' in state_dict: + state_dict = state_dict['model_ema'] + elif 'model' in state_dict: + state_dict = state_dict['model'] + elif 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + out_dict = {} + import re + for k, v in state_dict.items(): + k = k.replace('downsample_layers.0.', 'stem.') + k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) + k = k.replace('dwconv', 'conv_dw') + k = k.replace('pwconv', 'mlp.fc') + k = k.replace('head.', 'head.fc.') + if k.startswith('norm.'): + k = k.replace('norm', 'head.norm') + if v.ndim == 2 and 'head' not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + return out_dict + + +def _create_edgenext(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + EdgeNeXt, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs) + return model + + +@register_model +def edgenext_xx_small(pretrained=False, **kwargs): + # 1.33M & 260.58M @ 256 resolution + # 71.23% Top-1 accuracy + # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS + # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS + model_kwargs = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4), **kwargs) + return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_x_small(pretrained=False, **kwargs): + # 2.34M & 538.0M @ 256 resolution + # 75.00% Top-1 accuracy + # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=31.61 versus 28.49 for MobileViT_XS + # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS + model_kwargs = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4), **kwargs) + return _create_edgenext('edgenext_x_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_small(pretrained=False, **kwargs): + # 5.59M & 1260.59M @ 256 resolution + # 79.43% Top-1 accuracy + # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=20.47 versus 18.86 for MobileViT_S + # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S + model_kwargs = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304), **kwargs) + return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_small_rw(pretrained=False, **kwargs): + # 5.59M & 1260.59M @ 256 resolution + # 79.43% Top-1 accuracy + # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=20.47 versus 18.86 for MobileViT_S + # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S + model_kwargs = dict( + depths=(3, 3, 9, 3), dims=(48, 96, 192, 384), + downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs) + return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **model_kwargs) + From 70d6d2c4847982a8f20c4233a28ba84ea9485868 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:17:05 -0700 Subject: [PATCH 19/45] support test_crop_size in data config resolve --- timm/data/config.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/timm/data/config.py b/timm/data/config.py index 38f5689a..78176e4b 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -64,11 +64,15 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v new_config['std'] = default_cfg['std'] # resolve default crop percentage - new_config['crop_pct'] = DEFAULT_CROP_PCT + crop_pct = DEFAULT_CROP_PCT if 'crop_pct' in args and args['crop_pct'] is not None: - new_config['crop_pct'] = args['crop_pct'] - elif 'crop_pct' in default_cfg: - new_config['crop_pct'] = default_cfg['crop_pct'] + crop_pct = args['crop_pct'] + else: + if use_test_size and 'test_crop_pct' in default_cfg: + crop_pct = default_cfg['test_crop_pct'] + elif 'crop_pct' in default_cfg: + crop_pct = default_cfg['crop_pct'] + new_config['crop_pct'] = crop_pct if verbose: _logger.info('Data processing configuration for current model + dataset:') From 188c194b0f7bad1aa6c5db46e04c3ef63d2b10e6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:17:28 -0700 Subject: [PATCH 20/45] Left some experiment stem code in convnext by mistake --- timm/models/convnext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 662695c7..138e5030 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -434,7 +434,7 @@ def convnext_tiny_hnf(pretrained=False, **kwargs): @register_model def convnext_tiny_hnfd(pretrained=False, **kwargs): model_args = dict( - depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, stem_type='dual', **kwargs) + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) return model From c170ba317318599e759d4f004e6ee6aebf1fc258 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:18:06 -0700 Subject: [PATCH 21/45] Add weights for resnet10t, resnet14t, and resnetaa50 models. Fix #1314 --- timm/models/resnet.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 476ffe91..28f3cdba 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -35,6 +35,16 @@ def _cfg(url='', **kwargs): default_cfgs = { # ResNet and Wide ResNet + 'resnet10t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet10t_176_c3-f3215ab1.pth', + input_size=(3, 176, 176), pool_size=(6, 6), + test_crop_pct=0.95, test_input_size=(3, 224, 224), + first_conv='conv1.0'), + 'resnet14t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet14t_176_c3-c4ed2c37.pth', + input_size=(3, 176, 176), pool_size=(6, 6), + test_crop_pct=0.95, test_input_size=(3, 224, 224), + first_conv='conv1.0'), 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet18d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth', @@ -262,6 +272,10 @@ default_cfgs = { 'resnetblur101d': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), + 'resnetaa50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0, + interpolation='bicubic', first_conv='conv1.0'), 'resnetaa50d': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), @@ -1454,6 +1468,14 @@ def resnetblur101d(pretrained=False, **kwargs): return _create_resnet('resnetblur101d', pretrained, **model_args) +@register_model +def resnetaa50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with avgpool anti-aliasing + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, **kwargs) + return _create_resnet('resnetaa50', pretrained, **model_args) + + @register_model def resnetaa50d(pretrained=False, **kwargs): """Constructs a ResNet-50-D model with avgpool anti-aliasing From 377e9bfa217b60601fb6473022970f115a5455ca Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:18:52 -0700 Subject: [PATCH 22/45] Add TPU trained darknet53 weights. Add mising pretrain_cfg for some csp/darknet models. --- timm/models/cspnet.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 095e4701..77473052 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -45,14 +45,18 @@ default_cfgs = { 'cspresnet50w': _cfg(url=''), 'cspresnext50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth', - input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.875 # FIXME I trained this at 224x224, not 256 like ref impl ), 'cspdarknet53': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'), 'darknet17': _cfg(url=''), 'darknet21': _cfg(url=''), - 'darknet53': _cfg(url=''), + 'sedarknet21': _cfg(url=''), + 'darknet53': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic' + ), + 'darknetaa53': _cfg(url=''), 'cs2darknet_m': _cfg( url=''), From dd9b8f57c4862d4edd87dc1e0a3b34ff005a27f4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:20:45 -0700 Subject: [PATCH 23/45] Add feature_info to edgenext for features_only support, hopefully fix some fx / test errors --- timm/models/edgenext.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 0f8b0464..97971ba6 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -17,8 +17,8 @@ from torch import nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers import trunc_normal_tf_ -from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from .fx_features import register_notrace_module +from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .registry import register_model @@ -53,6 +53,7 @@ default_cfgs = dict( ) +@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): def __init__(self, hidden_dim=32, dim=768, temperature=10000): super().__init__() @@ -349,6 +350,7 @@ class EdgeNeXt(nn.Module): self.drop_rate = drop_rate norm_layer = partial(LayerNorm2d, eps=1e-6) norm_layer_cl = partial(nn.LayerNorm, eps=1e-6) + self.feature_info = [] assert stem_type in ('patch', 'overlap') if stem_type == 'patch': @@ -362,14 +364,18 @@ class EdgeNeXt(nn.Module): norm_layer(dims[0]), ) + curr_stride = 4 stages = [] dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] in_chs = dims[0] for i in range(4): + stride = 2 if curr_stride == 2 or i > 0 else 1 + # FIXME support dilation / output_stride + curr_stride *= stride stages.append(EdgeNeXtStage( in_chs=in_chs, out_chs=dims[i], - stride=2 if i > 0 else 1, + stride=stride, depth=depths[i], num_global_blocks=global_block_counts[i], num_heads=heads[i], @@ -385,7 +391,10 @@ class EdgeNeXt(nn.Module): norm_layer_cl=norm_layer_cl, act_layer=act_layer, )) + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 in_chs = dims[i] + self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) self.num_features = dims[-1] From d76530582164740e65b6992148d2a755f16cde6b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:56:17 -0700 Subject: [PATCH 24/45] Remove first_conv for resnetaa50 def --- timm/models/resnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 28f3cdba..e5a6b791 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -274,8 +274,7 @@ default_cfgs = { interpolation='bicubic', first_conv='conv1.0'), 'resnetaa50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0, - interpolation='bicubic', first_conv='conv1.0'), + test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic'), 'resnetaa50d': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), From d0c5bd57223c3f1da58219f497fe48d478f873da Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Jul 2022 08:32:41 -0700 Subject: [PATCH 25/45] Rename cs2->cs3 for darknets. Fix features_only for cs3 darknets. --- timm/models/cspnet.py | 62 ++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 77473052..4591f101 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -58,13 +58,13 @@ default_cfgs = { ), 'darknetaa53': _cfg(url=''), - 'cs2darknet_m': _cfg( + 'cs3darknet_m': _cfg( url=''), - 'cs2darknet_l': _cfg( + 'cs3darknet_l': _cfg( url=''), - 'cs2darknet_f_m': _cfg( + 'cs3darknet_focus_m': _cfg( url=''), - 'cs2darknet_f_l': _cfg( + 'cs3darknet_focus_l': _cfg( url=''), } @@ -185,7 +185,7 @@ model_cfgs = dict( ), ), - cs2darknet_m=dict( + cs3darknet_m=dict( stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''), stage=dict( out_chs=(96, 192, 384, 768), @@ -196,12 +196,11 @@ model_cfgs = dict( avg_down=False, ), ), - - cs2darknet_f_m=dict( - stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), + cs3darknet_l=dict( + stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), stage=dict( - out_chs=(96, 192, 384, 768), - depth=(2, 4, 6, 2), + out_chs=(128, 256, 512, 1024), + depth=(3, 6, 9, 3), stride=(2,) * 4, bottle_ratio=(1.,) * 4, block_ratio=(0.5,) * 4, @@ -209,19 +208,18 @@ model_cfgs = dict( ), ), - cs2darknet_l=dict( - stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), + cs3darknet_focus_m=dict( + stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), stage=dict( - out_chs=(128, 256, 512, 1024), - depth=(3, 6, 9, 3), + out_chs=(96, 192, 384, 768), + depth=(2, 4, 6, 2), stride=(2,) * 4, bottle_ratio=(1.,) * 4, block_ratio=(0.5,) * 4, avg_down=False, ), ), - - cs2darknet_f_l=dict( + cs3darknet_focus_l=dict( stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''), stage=dict( out_chs=(128, 256, 512, 1024), @@ -438,9 +436,9 @@ class CrossStage(nn.Module): return out -class CrossStage2(nn.Module): - """Cross Stage v2. - Similar to CrossStage, but with one transition conv for the concat output. +class CrossStage3(nn.Module): + """Cross Stage 3. + Similar to CrossStage, but with only one transition conv for the output. """ def __init__( self, @@ -461,7 +459,7 @@ class CrossStage2(nn.Module): block_fn=ResBottleneck, **block_kwargs ): - super(CrossStage2, self).__init__() + super(CrossStage3, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) @@ -696,8 +694,12 @@ def _init_weights(module, name, zero_init_last=False): def _create_cspnet(variant, pretrained=False, **kwargs): - # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4)) + if variant.startswith('darknet') or variant.startswith('cspdarknet'): + # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] + default_out_indices = (0, 1, 2, 3, 4, 5) + else: + default_out_indices = (0, 1, 2, 3, 4) + out_indices = kwargs.pop('out_indices', default_out_indices) return build_model_with_cfg( CspNet, variant, pretrained, model_cfg=model_cfgs[variant], @@ -757,24 +759,24 @@ def darknetaa53(pretrained=False, **kwargs): @register_model -def cs2darknet_m(pretrained=False, **kwargs): +def cs3darknet_m(pretrained=False, **kwargs): return _create_cspnet( - 'cs2darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + 'cs3darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) @register_model -def cs2darknet_l(pretrained=False, **kwargs): +def cs3darknet_l(pretrained=False, **kwargs): return _create_cspnet( - 'cs2darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + 'cs3darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) @register_model -def cs2darknet_f_m(pretrained=False, **kwargs): +def cs3darknet_focus_m(pretrained=False, **kwargs): return _create_cspnet( - 'cs2darknet_f_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + 'cs3darknet_focus_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) @register_model -def cs2darknet_f_l(pretrained=False, **kwargs): +def cs3darknet_focus_l(pretrained=False, **kwargs): return _create_cspnet( - 'cs2darknet_f_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) \ No newline at end of file + 'cs3darknet_focus_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) \ No newline at end of file From 7d4b3807d5c40b0f8d7e66d27a7672684e482996 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 4 Jul 2022 22:25:22 -0700 Subject: [PATCH 26/45] Support DeiT-3 (Revenge of the ViT) checkpoints. Add non-overlapping (w/ class token) pos-embed support to vit. --- timm/models/deit.py | 204 +++++++++++++++++++++++++++++- timm/models/vision_transformer.py | 64 +++++++--- 2 files changed, 247 insertions(+), 21 deletions(-) diff --git a/timm/models/deit.py b/timm/models/deit.py index e6b4b025..a2f43b91 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -1,7 +1,10 @@ """ DeiT - Data-efficient Image Transformers DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below -paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118 Modifications copyright 2021, Ross Wightman """ @@ -53,6 +56,46 @@ default_cfgs = { url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), + + 'deit3_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'), + 'deit3_small_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'), + 'deit3_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_large_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'), + 'deit3_large_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_huge_patch14_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'), + + 'deit3_small_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth', + crop_pct=1.0), + 'deit3_small_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_base_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth', + crop_pct=1.0), + 'deit3_base_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_large_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth', + crop_pct=1.0), + 'deit3_large_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_huge_patch14_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth', + crop_pct=1.0), } @@ -68,9 +111,10 @@ class VisionTransformerDistilled(VisionTransformer): super().__init__(*args, **kwargs, weight_init='skip') assert self.global_pool in ('token',) - self.num_tokens = 2 + self.num_prefix_tokens = 2 self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim)) self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() self.distilled_training = False # must set this True to train w/ distillation token @@ -220,3 +264,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs): model = _create_deit( 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) return model + + +@register_model +def deit3_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_huge_patch14_224(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 8551feae..022052d0 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -325,8 +325,8 @@ class VisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', - embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): + class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): """ Args: img_size (int, tuple): input image size @@ -360,15 +360,17 @@ class VisionTransformer(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 if class_token else 0 + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None - self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -428,11 +430,24 @@ class VisionTransformer(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def _pos_embed(self, x): + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + self.pos_embed + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed + return self.pos_drop(x) + def forward_features(self, x): x = self.patch_embed(x) - if self.cls_token is not None: - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = self.pos_drop(x + self.pos_embed) + x = self._pos_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: @@ -442,7 +457,7 @@ class VisionTransformer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: - x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) return x if pre_logits else self.head(x) @@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) if pos_embed_w.shape != model.pos_embed.shape: pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights - pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + pos_embed_w, + model.pos_embed, + getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) @@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) -def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): +def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] - if num_tokens: - posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] - ntok_new -= num_tokens + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] + ntok_new -= num_prefix_tokens else: - posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 @@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) - posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) return posemb def checkpoint_filter_fn(state_dict, model): """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re out_dict = {} if 'model' in state_dict: # For deit models state_dict = state_dict['model'] + for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: # For old models that I trained prior to conv based patchification O, I, H, W = model.patch_embed.proj.weight.shape v = v.reshape(O, -1, H, W) - elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( - v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + v, + model.pos_embed, + getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + elif 'gamma_' in k: + # remap layer-scale gamma into sub-module (deit3 models) + k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) elif 'pre_logits' in k: # NOTE representation layer removed as not used in latest 21k/1k pretrained weights continue From bfc0dccb0ed1026f596797818ab865ea53ef3d2c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 14:23:20 -0700 Subject: [PATCH 27/45] Improve image extension handling, add methods to modify / get defaults. Fix #1335 fix #1274. --- timm/data/__init__.py | 5 ++- timm/data/parsers/__init__.py | 1 + timm/data/parsers/constants.py | 1 - timm/data/parsers/img_extensions.py | 50 ++++++++++++++++++++++++ timm/data/parsers/parser_factory.py | 1 - timm/data/parsers/parser_image_folder.py | 29 ++++++++++++-- timm/data/parsers/parser_image_in_tar.py | 29 ++++++++------ timm/data/parsers/parser_image_tar.py | 10 +++-- 8 files changed, 103 insertions(+), 23 deletions(-) delete mode 100644 timm/data/parsers/constants.py create mode 100644 timm/data/parsers/img_extensions.py diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 7d3cb2b4..0eb10a66 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset_factory import create_dataset from .loader import create_loader from .mixup import Mixup, FastCollateMixup -from .parsers import create_parser +from .parsers import create_parser,\ + get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions from .real_labels import RealLabelsImagenet from .transforms import * -from .transforms_factory import create_transform \ No newline at end of file +from .transforms_factory import create_transform diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py index eeb44e37..4e820d5e 100644 --- a/timm/data/parsers/__init__.py +++ b/timm/data/parsers/__init__.py @@ -1 +1,2 @@ from .parser_factory import create_parser +from .img_extensions import * diff --git a/timm/data/parsers/constants.py b/timm/data/parsers/constants.py deleted file mode 100644 index e7ba484e..00000000 --- a/timm/data/parsers/constants.py +++ /dev/null @@ -1 +0,0 @@ -IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') diff --git a/timm/data/parsers/img_extensions.py b/timm/data/parsers/img_extensions.py new file mode 100644 index 00000000..45c85aab --- /dev/null +++ b/timm/data/parsers/img_extensions.py @@ -0,0 +1,50 @@ +from copy import deepcopy + +__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] + + +IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use +_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync + + +def _set_extensions(extensions): + global IMG_EXTENSIONS + global _IMG_EXTENSIONS_SET + dedupe = set() # NOTE de-duping tuple while keeping original order + IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) + _IMG_EXTENSIONS_SET = set(extensions) + + +def _valid_extension(x: str): + return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') + + +def is_img_extension(ext): + return ext in _IMG_EXTENSIONS_SET + + +def get_img_extensions(as_set=False): + return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) + + +def set_img_extensions(extensions): + assert len(extensions) + for x in extensions: + assert _valid_extension(x) + _set_extensions(extensions) + + +def add_img_extensions(ext): + if not isinstance(ext, (list, tuple, set)): + ext = (ext,) + for x in ext: + assert _valid_extension(x) + extensions = IMG_EXTENSIONS + tuple(ext) + _set_extensions(extensions) + + +def del_img_extensions(ext): + if not isinstance(ext, (list, tuple, set)): + ext = (ext,) + extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) + _set_extensions(extensions) diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py index 892090ad..0665c02a 100644 --- a/timm/data/parsers/parser_factory.py +++ b/timm/data/parsers/parser_factory.py @@ -1,7 +1,6 @@ import os from .parser_image_folder import ParserImageFolder -from .parser_image_tar import ParserImageTar from .parser_image_in_tar import ParserImageInTar diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/parsers/parser_image_folder.py index ed349009..3d22a17b 100644 --- a/timm/data/parsers/parser_image_folder.py +++ b/timm/data/parsers/parser_image_folder.py @@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default. Hacked together by / Copyright 2020 Ross Wightman """ import os +from typing import Dict, List, Optional, Set, Tuple, Union from timm.utils.misc import natural_key -from .parser import Parser from .class_map import load_class_map -from .constants import IMG_EXTENSIONS +from .img_extensions import get_img_extensions +from .parser import Parser + + +def find_images_and_targets( + folder: str, + types: Optional[Union[List, Tuple, Set]] = None, + class_to_idx: Optional[Dict] = None, + leaf_name_only: bool = True, + sort: bool = True +): + """ Walk folder recursively to discover images and map them to classes by folder names. + Args: + folder: root of folder to recrusively search + types: types (file extensions) to search for in path + class_to_idx: specify mapping for class (folder name) to class index if set + leaf_name_only: use only leaf-name of folder walk for class names + sort: re-sort found images by name (for consistent ordering) -def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): + Returns: + A list of image and target tuples, class_to_idx mapping + """ + types = get_img_extensions(as_set=True) if not types else set(types) labels = [] filenames = [] for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): @@ -51,7 +71,8 @@ class ParserImageFolder(Parser): self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) if len(self.samples) == 0: raise RuntimeError( - f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') + f'Found 0 images in subfolders of {root}. ' + f'Supported image extensions are {", ".join(get_img_extensions())}') def __getitem__(self, index): path, target = self.samples[index] diff --git a/timm/data/parsers/parser_image_in_tar.py b/timm/data/parsers/parser_image_in_tar.py index c6ada962..4fcad797 100644 --- a/timm/data/parsers/parser_image_in_tar.py +++ b/timm/data/parsers/parser_image_in_tar.py @@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure. Hacked together by / Copyright 2020 Ross Wightman """ +import logging import os -import tarfile import pickle -import logging -import numpy as np +import tarfile from glob import glob -from typing import List, Dict +from typing import List, Tuple, Dict, Set, Optional, Union + +import numpy as np from timm.utils.misc import natural_key -from .parser import Parser from .class_map import load_class_map -from .constants import IMG_EXTENSIONS - +from .img_extensions import get_img_extensions +from .parser import Parser _logger = logging.getLogger(__name__) CACHE_FILENAME_SUFFIX = '_tarinfos.pickle' @@ -39,7 +39,7 @@ class TarState: self.tf = None -def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS): +def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]): sample_count = 0 for i, ti in enumerate(tf): if not ti.isfile(): @@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE return sample_count -def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True): +def extract_tarinfos( + root, + class_name_to_idx: Optional[Dict] = None, + cache_tarinfo: Optional[bool] = None, + extensions: Optional[Union[List, Tuple, Set]] = None, + sort: bool = True +): + extensions = get_img_extensions(as_set=True) if not extensions else set(extensions) root_is_tar = False if os.path.isfile(root): assert os.path.splitext(root)[-1].lower() == '.tar' @@ -176,8 +183,8 @@ class ParserImageInTar(Parser): self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos( self.root, class_name_to_idx=class_name_to_idx, - cache_tarinfo=cache_tarinfo, - extensions=IMG_EXTENSIONS) + cache_tarinfo=cache_tarinfo + ) self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()} if len(tarfiles) == 1 and tarfiles[0][0] is None: self.root_is_tar = True diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/parsers/parser_image_tar.py index 467537f4..c2ed429d 100644 --- a/timm/data/parsers/parser_image_tar.py +++ b/timm/data/parsers/parser_image_tar.py @@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman import os import tarfile -from .parser import Parser -from .class_map import load_class_map -from .constants import IMG_EXTENSIONS from timm.utils.misc import natural_key +from .class_map import load_class_map +from .img_extensions import get_img_extensions +from .parser import Parser + def extract_tarinfo(tarfile, class_to_idx=None, sort=True): + extensions = get_img_extensions(as_set=True) files = [] labels = [] for ti in tarfile.getmembers(): @@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True): dirname, basename = os.path.split(ti.path) label = os.path.basename(dirname) ext = os.path.splitext(basename)[1] - if ext.lower() in IMG_EXTENSIONS: + if ext.lower() in extensions: files.append(ti) labels.append(label) if class_to_idx is None: From 06307b8b41da5783f38167f8ab609f83fb6b351d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 14:37:58 -0700 Subject: [PATCH 28/45] Remove experimental downsample in block support in ConvNeXt. Experiment further before keeping it in. --- timm/models/convnext.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 138e5030..be0c9a66 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -17,7 +17,6 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d from .registry import register_model @@ -124,7 +123,6 @@ class ConvNeXtBlock(nn.Module): norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp - self.shortcut_after_dw = stride > 1 self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias) self.norm = norm_layer(dim_out) @@ -135,9 +133,6 @@ class ConvNeXtBlock(nn.Module): def forward(self, x): shortcut = x x = self.conv_dw(x) - if self.shortcut_after_dw: - shortcut = x - if self.use_conv_mlp: x = self.norm(x) x = self.mlp(x) @@ -150,7 +145,6 @@ class ConvNeXtBlock(nn.Module): x = x.mul(self.gamma.reshape(1, -1, 1, 1)) x = self.drop_path(x) + shortcut - #print('b', x.shape) return x @@ -164,7 +158,6 @@ class ConvNeXtStage(nn.Module): depth=2, drop_path_rates=None, ls_init_value=1.0, - downsample_block=False, conv_mlp=False, conv_bias=True, norm_layer=None, @@ -173,14 +166,14 @@ class ConvNeXtStage(nn.Module): super().__init__() self.grad_checkpointing = False - if downsample_block or (in_chs == out_chs and stride == 1): - self.downsample = nn.Identity() - else: + if in_chs != out_chs or stride > 1: self.downsample = nn.Sequential( norm_layer(in_chs), nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias), ) in_chs = out_chs + else: + self.downsample = nn.Identity() drop_path_rates = drop_path_rates or [0.] * depth stage_blocks = [] @@ -188,7 +181,6 @@ class ConvNeXtStage(nn.Module): stage_blocks.append(ConvNeXtBlock( dim=in_chs, dim_out=out_chs, - stride=stride if downsample_block and i == 0 else 1, drop_path=drop_path_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, @@ -236,7 +228,6 @@ class ConvNeXt(nn.Module): stem_stride=4, head_init_scale=1., head_norm_first=False, - downsample_block=False, conv_mlp=False, conv_bias=True, norm_layer=None, @@ -291,7 +282,6 @@ class ConvNeXt(nn.Module): depth=depths[i], drop_path_rates=dp_rates[i], ls_init_value=ls_init_value, - downsample_block=downsample_block, conv_mlp=conv_mlp, conv_bias=conv_bias, norm_layer=norm_layer, @@ -418,7 +408,7 @@ def convnext_nano_hnf(pretrained=False, **kwargs): @register_model def convnext_nano_ols(pretrained=False, **kwargs): model_args = dict( - depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), downsample_block=True, + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs) model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args) return model @@ -426,7 +416,8 @@ def convnext_nano_ols(pretrained=False, **kwargs): @register_model def convnext_tiny_hnf(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) + model_args = dict( + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) return model From eca09b86423d8f441e55f27205efa1b3c9e77d41 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 14:41:01 -0700 Subject: [PATCH 29/45] Add MobileVitV2 support. Fix #1332. Move GroupNorm1 to common layers (used in poolformer + mobilevitv2). Keep ol custom ConvNeXt LayerNorm2d impl as LayerNormExp2d for reference. --- timm/models/layers/__init__.py | 2 +- timm/models/layers/create_attn.py | 2 +- timm/models/layers/norm.py | 54 +++- timm/models/mobilevit.py | 443 +++++++++++++++++++++++++++++- timm/models/poolformer.py | 11 +- 5 files changed, 489 insertions(+), 23 deletions(-) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index b1f452ff..b9eeec0f 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -25,7 +25,7 @@ from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, LayerNorm2d +from .norm import GroupNorm, GroupNorm1, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index 028c0f75..cc7e91ea 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -22,7 +22,7 @@ def get_attn(attn_type): if isinstance(attn_type, torch.nn.Module): return attn_type module_cls = None - if attn_type is not None: + if attn_type: if isinstance(attn_type, str): attn_type = attn_type.lower() # Lightweight attention modules (channel and/or coarse spatial). diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 345f67bc..1677dbfa 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -14,11 +14,59 @@ class GroupNorm(nn.GroupNorm): return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) +class GroupNorm1(nn.GroupNorm): + """ Group Normalization with 1 group. + Input: tensor in shape [B, C, *] + """ + + def __init__(self, num_channels, **kwargs): + super().__init__(1, num_channels, **kwargs) + + class LayerNorm2d(nn.LayerNorm): - """ LayerNorm for channels of '2D' spatial BCHW tensors """ - def __init__(self, num_channels, eps=1e-6): - super().__init__(num_channels, eps=eps) + """ LayerNorm for channels of '2D' spatial NCHW tensors """ + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +def _is_contiguous(tensor: torch.Tensor) -> bool: + # jit is oh so lovely :/ + # if torch.jit.is_tracing(): + # return True + if torch.jit.is_scripting(): + return tensor.is_contiguous() + else: + return tensor.is_contiguous(memory_format=torch.contiguous_format) + + +@torch.jit.script +def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): + s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) + x = (x - u) * torch.rsqrt(s + eps) + x = x * weight[:, None, None] + bias[:, None, None] + return x + + +class LayerNormExp2d(nn.LayerNorm): + """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). + + Experimental implementation w/ manual norm for tensors non-contiguous tensors. + + This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last + layout. However, benefits are not always clear and can perform worse on other GPUs. + """ + + def __init__(self, num_channels, eps=1e-6): + super().__init__(num_channels, eps=eps) + + def forward(self, x) -> torch.Tensor: + if _is_contiguous(x): + x = F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + else: + x = _layer_norm_cf(x, self.weight, self.bias, self.eps) + return x diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 1c55bd1c..2a3ab924 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -1,7 +1,8 @@ """ MobileViT Paper: -`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 +V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 +V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680 MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below) License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source) @@ -13,7 +14,7 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022 # Copyright (C) 2020 Apple Inc. All Rights Reserved. # import math -from typing import Union, Callable, Dict, Tuple, Optional +from typing import Union, Callable, Dict, Tuple, Optional, Sequence import torch from torch import nn @@ -21,7 +22,7 @@ import torch.nn.functional as F from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups from .fx_features import register_notrace_module -from .layers import to_2tuple, make_divisible +from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath from .vision_transformer import Block as TransformerBlock from .helpers import build_model_with_cfg from .registry import register_model @@ -48,6 +49,48 @@ default_cfgs = { 'mobilevit_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'), 'semobilevit_s': _cfg(), + + 'mobilevitv2_050': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth', + crop_pct=0.888), + 'mobilevitv2_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth', + crop_pct=0.888), + 'mobilevitv2_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth', + crop_pct=0.888), + 'mobilevitv2_125': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth', + crop_pct=0.888), + 'mobilevitv2_150': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth', + crop_pct=0.888), + 'mobilevitv2_175': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth', + crop_pct=0.888), + 'mobilevitv2_200': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth', + crop_pct=0.888), + + 'mobilevitv2_150_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth', + crop_pct=0.888), + 'mobilevitv2_175_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth', + crop_pct=0.888), + 'mobilevitv2_200_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth', + crop_pct=0.888), + + 'mobilevitv2_150_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'mobilevitv2_175_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'mobilevitv2_200_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), } @@ -72,6 +115,40 @@ def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4, ) +def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5): + # inverted residual + mobilevit blocks as per MobileViT network + return ( + _inverted_residual_block(d=d, c=c, s=s, br=br), + ByoBlockCfg( + type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1, + block_kwargs=dict( + transformer_depth=transformer_depth, + patch_size=patch_size) + ) + ) + + +def _mobilevitv2_cfg(multiplier=1.0): + chs = (64, 128, 256, 384, 512) + if multiplier != 1.0: + chs = tuple([int(c * multiplier) for c in chs]) + cfg = ByoModelCfg( + blocks=( + _inverted_residual_block(d=1, c=chs[0], s=1, br=2.0), + _inverted_residual_block(d=2, c=chs[1], s=2, br=2.0), + _mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2), + _mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4), + _mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3), + ), + stem_chs=int(32 * multiplier), + stem_type='3x3', + stem_pool='', + downsample='', + act_layer='silu', + ) + return cfg + + model_cfgs = dict( mobilevit_xxs=ByoModelCfg( blocks=( @@ -137,11 +214,19 @@ model_cfgs = dict( attn_kwargs=dict(rd_ratio=1/8), num_features=640, ), + + mobilevitv2_050=_mobilevitv2_cfg(.50), + mobilevitv2_075=_mobilevitv2_cfg(.75), + mobilevitv2_125=_mobilevitv2_cfg(1.25), + mobilevitv2_100=_mobilevitv2_cfg(1.0), + mobilevitv2_150=_mobilevitv2_cfg(1.5), + mobilevitv2_175=_mobilevitv2_cfg(1.75), + mobilevitv2_200=_mobilevitv2_cfg(2.0), ) @register_notrace_module -class MobileViTBlock(nn.Module): +class MobileVitBlock(nn.Module): """ MobileViT block Paper: https://arxiv.org/abs/2110.02178?context=cs.LG """ @@ -165,9 +250,9 @@ class MobileViTBlock(nn.Module): drop_path_rate: float = 0., layers: LayerFn = None, transformer_norm_layer: Callable = nn.LayerNorm, - downsample: str = '' + **kwargs, # eat unused args ): - super(MobileViTBlock, self).__init__() + super(MobileVitBlock, self).__init__() layers = layers or LayerFn() groups = num_groups(group_size, in_chs) @@ -241,7 +326,270 @@ class MobileViTBlock(nn.Module): return x -register_block('mobilevit', MobileViTBlock) +class LinearSelfAttention(nn.Module): + """ + This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680` + This layer can be used for self- as well as cross-attention. + Args: + embed_dim (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + attn_drop (float): Dropout value for context scores. Default: 0.0 + bias (bool): Use bias in learnable layers. Default: True + Shape: + - Input: :math:`(N, C, P, N)` where :math:`N` is the batch size, :math:`C` is the input channels, + :math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches + - Output: same as the input + .. note:: + For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels + in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor, + we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be + expensive on resource-constrained devices) that may be required to convert the unfolded tensor from + channel-first to channel-last format in case of a linear layer. + """ + + def __init__( + self, + embed_dim: int, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + + self.qkv_proj = nn.Conv2d( + in_channels=embed_dim, + out_channels=1 + (2 * embed_dim), + bias=bias, + kernel_size=1, + ) + self.attn_drop = nn.Dropout(attn_drop) + self.out_proj = nn.Conv2d( + in_channels=embed_dim, + out_channels=embed_dim, + bias=bias, + kernel_size=1, + ) + self.out_drop = nn.Dropout(proj_drop) + + def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor: + # [B, C, P, N] --> [B, h + 2d, P, N] + qkv = self.qkv_proj(x) + + # Project x into query, key and value + # Query --> [B, 1, P, N] + # value, key --> [B, d, P, N] + query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1) + + # apply softmax along N dimension + context_scores = F.softmax(query, dim=-1) + context_scores = self.attn_drop(context_scores) + + # Compute context vector + # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1] + context_vector = (key * context_scores).sum(dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + out = self.out_drop(out) + return out + + @torch.jit.ignore() + def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + # x --> [B, C, P, N] + # x_prev = [B, C, P, M] + batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape + q_patch_area, q_num_patches = x.shape[-2:] + + assert ( + kv_patch_area == q_patch_area + ), "The number of pixels in a patch for query and key_value should be the same" + + # compute query, key, and value + # [B, C, P, M] --> [B, 1 + d, P, M] + qk = F.conv2d( + x_prev, + weight=self.qkv_proj.weight[:self.embed_dim + 1], + bias=self.qkv_proj.bias[:self.embed_dim + 1], + ) + + # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M] + query, key = qk.split([1, self.embed_dim], dim=1) + # [B, C, P, N] --> [B, d, P, N] + value = F.conv2d( + x, + weight=self.qkv_proj.weight[self.embed_dim + 1], + bias=self.qkv_proj.bias[self.embed_dim + 1] if self.qkv_proj.bias is not None else None, + ) + + # apply softmax along M dimension + context_scores = F.softmax(query, dim=-1) + context_scores = self.attn_drop(context_scores) + + # compute context vector + # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1] + context_vector = (key * context_scores).sum(dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + out = self.out_drop(out) + return out + + def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + if x_prev is None: + return self._forward_self_attn(x) + else: + return self._forward_cross_attn(x, x_prev=x_prev) + + +class LinearTransformerBlock(nn.Module): + """ + This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 paper <>`_ + Args: + embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)` + mlp_ratio (float): Inner dimension ratio of the FFN relative to embed_dim + drop (float): Dropout rate. Default: 0.0 + attn_drop (float): Dropout rate for attention in multi-head attention. Default: 0.0 + drop_path (float): Stochastic depth rate Default: 0.0 + norm_layer (Callable): Normalization layer. Default: layer_norm_2d + Shape: + - Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim, + :math:`P` is number of pixels in a patch, and :math:`N` is number of patches, + - Output: same shape as the input + """ + + def __init__( + self, + embed_dim: int, + mlp_ratio: float = 2.0, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer=None, + norm_layer=None, + ) -> None: + super().__init__() + act_layer = act_layer or nn.SiLU + norm_layer = norm_layer or GroupNorm1 + + self.norm1 = norm_layer(embed_dim) + self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop) + self.drop_path1 = DropPath(drop_path) + + self.norm2 = norm_layer(embed_dim) + self.mlp = ConvMlp( + in_features=embed_dim, + hidden_features=int(embed_dim * mlp_ratio), + act_layer=act_layer, + drop=drop) + self.drop_path2 = DropPath(drop_path) + + def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + if x_prev is None: + # self-attention + x = x + self.drop_path1(self.attn(self.norm1(x))) + else: + # cross-attention + res = x + x = self.norm1(x) # norm + x = self.attn(x, x_prev) # attn + x = self.drop_path1(x) + res # residual + + # Feed forward network + x = x + self.drop_path2(self.mlp(self.norm2(x))) + return x + + +@register_notrace_module +class MobileVitV2Block(nn.Module): + """ + This class defines the `MobileViTv2 block <>`_ + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 3, + bottle_ratio: float = 1.0, + group_size: Optional[int] = 1, + dilation: Tuple[int, int] = (1, 1), + mlp_ratio: float = 2.0, + transformer_dim: Optional[int] = None, + transformer_depth: int = 2, + patch_size: int = 8, + attn_drop: float = 0., + drop: int = 0., + drop_path_rate: float = 0., + layers: LayerFn = None, + transformer_norm_layer: Callable = GroupNorm1, + **kwargs, # eat unused args + ): + super(MobileVitV2Block, self).__init__() + layers = layers or LayerFn() + groups = num_groups(group_size, in_chs) + out_chs = out_chs or in_chs + transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs) + + self.conv_kxk = layers.conv_norm_act( + in_chs, in_chs, kernel_size=kernel_size, + stride=1, groups=groups, dilation=dilation[0]) + self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False) + + self.transformer = nn.Sequential(*[ + LinearTransformerBlock( + transformer_dim, + mlp_ratio=mlp_ratio, + attn_drop=attn_drop, + drop=drop, + drop_path=drop_path_rate, + act_layer=layers.act, + norm_layer=transformer_norm_layer + ) + for _ in range(transformer_depth) + ]) + self.norm = transformer_norm_layer(transformer_dim) + + self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False) + + self.patch_size = to_2tuple(patch_size) + self.patch_area = self.patch_size[0] * self.patch_size[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + patch_h, patch_w = self.patch_size + new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w + num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w + num_patches = num_patch_h * num_patch_w # N + if new_h != H or new_w != W: + x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True) + + # Local representation + x = self.conv_kxk(x) + x = self.conv_1x1(x) + + # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N] + C = x.shape[1] + x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4) + x = x.reshape(B, C, -1, num_patches) + + # Global representations + x = self.transformer(x) + x = self.norm(x) + + # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W] + x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3) + x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w) + + x = self.conv_proj(x) + return x + + +register_block('mobilevit', MobileVitBlock) +register_block('mobilevit2', MobileVitV2Block) def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): @@ -252,6 +600,14 @@ def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): **kwargs) +def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + ByobNet, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + @register_model def mobilevit_xxs(pretrained=False, **kwargs): return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs) @@ -269,4 +625,75 @@ def mobilevit_s(pretrained=False, **kwargs): @register_model def semobilevit_s(pretrained=False, **kwargs): - return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) \ No newline at end of file + return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_050(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_075(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_100(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_125(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs) \ No newline at end of file diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 17d657b0..a95195b4 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -26,7 +26,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp +from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 from .registry import register_model @@ -80,15 +80,6 @@ class PatchEmbed(nn.Module): return x -class GroupNorm1(nn.GroupNorm): - """ Group Normalization with 1 group. - Input: tensor in shape [B, C, H, W] - """ - - def __init__(self, num_channels, **kwargs): - super().__init__(1, num_channels, **kwargs) - - class Pooling(nn.Module): def __init__(self, pool_size=3): super().__init__() From db0cee991028e772c0131e809da1e9e5ea60c568 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 14:43:27 -0700 Subject: [PATCH 30/45] Refactor cspnet configuration using dataclasses, update feature extraction for new cs3 variants. --- timm/models/cspnet.py | 710 ++++++++++++++++++++++++++---------------- 1 file changed, 448 insertions(+), 262 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 4591f101..f0a26baf 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,7 +12,10 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ +import collections.abc +from dataclasses import dataclass, field, asdict from functools import partial +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,7 +23,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP -from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer +from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, create_act_layer, make_divisible from .registry import register_model @@ -58,218 +61,278 @@ default_cfgs = { ), 'darknetaa53': _cfg(url=''), + 'cs3darknet_s': _cfg( + url=''), 'cs3darknet_m': _cfg( url=''), 'cs3darknet_l': _cfg( url=''), + 'cs3darknet_x': _cfg( + url=''), + + 'cs3darknet_focus_s': _cfg( + url=''), 'cs3darknet_focus_m': _cfg( url=''), 'cs3darknet_focus_l': _cfg( url=''), + 'cs3darknet_focus_x': _cfg( + url=''), + + 'cs3sedarknet_xdw': _cfg( + url=''), } +@dataclass +class CspStemCfg: + out_chs: Union[int, Tuple[int, ...]] = 32 + stride: Union[int, Tuple[int, ...]] = 2 + kernel_size: int = 3 + padding: Union[int, str] = '' + pool: Optional[str] = '' + + +def _pad_arg(x, n): + # pads an argument tuple to specified n by padding with last value + if not isinstance(x, (tuple, list)): + x = (x,) + curr_n = len(x) + pad_n = n - curr_n + if pad_n <= 0: + return x[:n] + return tuple(x + (x[-1],) * pad_n) + + +@dataclass +class CspStagesCfg: + depth: Tuple[int, ...] = (3, 3, 5, 2) # block depth (number of block repeats in stages) + out_chs: Tuple[int, ...] = (128, 256, 512, 1024) # number of output channels for blocks in stage + stride: Union[int, Tuple[int, ...]] = 2 # stride of stage + groups: Union[int, Tuple[int, ...]] = 1 # num kxk conv groups + block_ratio: Union[float, Tuple[float, ...]] = 1.0 + bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage + avg_down: Union[bool, Tuple[bool, ...]] = False + attn_layer: Optional[Union[str, Tuple[str, ...]]] = None + stage_type: Union[str, Tuple[str]] = 'csp' # stage type ('csp', 'cs2', 'dark') + block_type: Union[str, Tuple[str]] = 'bottle' # blocks type for stages ('bottle', 'dark') + + # cross-stage only + expand_ratio: Union[float, Tuple[float, ...]] = 1.0 + cross_linear: Union[bool, Tuple[bool, ...]] = False + down_growth: Union[bool, Tuple[bool, ...]] = False + + def __post_init__(self): + n = len(self.depth) + assert len(self.out_chs) == n + self.stride = _pad_arg(self.stride, n) + self.groups = _pad_arg(self.groups, n) + self.block_ratio = _pad_arg(self.block_ratio, n) + self.bottle_ratio = _pad_arg(self.bottle_ratio, n) + self.avg_down = _pad_arg(self.avg_down, n) + self.attn_layer = _pad_arg(self.attn_layer, n) + self.stage_type = _pad_arg(self.stage_type, n) + self.block_type = _pad_arg(self.block_type, n) + + self.expand_ratio = _pad_arg(self.expand_ratio, n) + self.cross_linear = _pad_arg(self.cross_linear, n) + self.down_growth = _pad_arg(self.down_growth, n) + + +@dataclass +class CspModelCfg: + stem: CspStemCfg + stages: CspStagesCfg + zero_init_last: bool = True # zero init last weight (usually bn) in residual path + act_layer: str = 'relu' + norm_layer: str = 'batchnorm' + aa_layer: Optional[str] = None # FIXME support string factory for this + + +def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False, act_layer='silu', focus=False): + if focus: + stem_cfg = CspStemCfg( + out_chs=make_divisible(64 * width_multiplier), + kernel_size=6, stride=2, padding=2, pool='') + else: + stem_cfg = CspStemCfg( + out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]), + kernel_size=3, stride=2, pool='') + return CspModelCfg( + stem=stem_cfg, + stages=CspStagesCfg( + out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]), + depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]), + stride=2, + bottle_ratio=1., + block_ratio=0.5, + avg_down=avg_down, + stage_type='cs3', + block_type='dark', + ), + act_layer=act_layer, + ) + + model_cfgs = dict( - cspresnet50=dict( - stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), - stage=dict( - out_chs=(128, 256, 512, 1024), + cspresnet50=CspModelCfg( + stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(2.,) * 4, - bottle_ratio=(0.5,) * 4, - block_ratio=(1.,) * 4, + out_chs=(128, 256, 512, 1024), + stride=(1, 2), + expand_ratio=2., + bottle_ratio=0.5, cross_linear=True, - ) + ), ), - cspresnet50d=dict( - stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), - stage=dict( - out_chs=(128, 256, 512, 1024), + cspresnet50d=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(2.,) * 4, - bottle_ratio=(0.5,) * 4, - block_ratio=(1.,) * 4, + out_chs=(128, 256, 512, 1024), + stride=(1,) + (2,), + expand_ratio=2., + bottle_ratio=0.5, + block_ratio=1., cross_linear=True, ) ), - cspresnet50w=dict( - stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), - stage=dict( - out_chs=(256, 512, 1024, 2048), + cspresnet50w=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(1.,) * 4, - bottle_ratio=(0.25,) * 4, - block_ratio=(0.5,) * 4, + out_chs=(256, 512, 1024, 2048), + stride=(1,) + (2,), + expand_ratio=1., + bottle_ratio=0.25, + block_ratio=0.5, cross_linear=True, ) ), - cspresnext50=dict( - stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), - stage=dict( - out_chs=(256, 512, 1024, 2048), + cspresnext50=CspModelCfg( + stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - groups=(32,) * 4, - exp_ratio=(1.,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, + out_chs=(256, 512, 1024, 2048), + stride=(1,) + (2,), + groups=32, + expand_ratio=1., + bottle_ratio=1., + block_ratio=0.5, cross_linear=True, ) ), - cspdarknet53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + cspdarknet53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - exp_ratio=(2.,) + (1.,) * 4, - bottle_ratio=(0.5,) + (1.0,) * 4, - block_ratio=(1.,) + (0.5,) * 4, + out_chs=(64, 128, 256, 512, 1024), + stride=2, + expand_ratio=(2.,) + (1.,), + bottle_ratio=(0.5,) + (1.,), + block_ratio=(1.,) + (0.5,), down_growth=True, - ) + block_type='dark', + ), + act_layer='leaky_relu', ), - darknet17=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + darknet17=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1,) * 5, - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) - ), - darknet21=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( out_chs=(64, 128, 256, 512, 1024), - depth=(1, 1, 1, 2, 2), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) + stride=(2,), + bottle_ratio=(0.5,), + block_ratio=(1.,), + stage_type='dark', + block_type='dark', + ), + act_layer='leaky_relu', ), - sedarknet21=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + darknet21=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1, 1, 1, 2, 2), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - attn_layer=('se',) * 5, - ) - ), - darknet53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( out_chs=(64, 128, 256, 512, 1024), - depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) - ), + stride=(2,), + bottle_ratio=(0.5,), + block_ratio=(1.,), + stage_type='dark', + block_type='dark', - darknetaa53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), - depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - avg_down=True, ), + act_layer='leaky_relu', ), + sedarknet21=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 1, 1, 2, 2), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + attn_layer='se', + stage_type='dark', + block_type='dark', - cs3darknet_m=dict( - stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''), - stage=dict( - out_chs=(96, 192, 384, 768), - depth=(2, 4, 6, 2), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, ), + act_layer='leaky_relu', ), - cs3darknet_l=dict( - stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), - stage=dict( - out_chs=(128, 256, 512, 1024), - depth=(3, 6, 9, 3), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, + darknet53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 2, 8, 8, 4), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + stage_type='dark', + block_type='dark', ), + act_layer='leaky_relu', ), - - cs3darknet_focus_m=dict( - stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), - stage=dict( - out_chs=(96, 192, 384, 768), - depth=(2, 4, 6, 2), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, + darknetaa53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 2, 8, 8, 4), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + avg_down=True, + stage_type='dark', + block_type='dark', ), + act_layer='leaky_relu', ), - cs3darknet_focus_l=dict( - stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''), - stage=dict( - out_chs=(128, 256, 512, 1024), - depth=(3, 6, 9, 3), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, - ), - ) -) + cs3darknet_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5), + cs3darknet_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67), + cs3darknet_l=_cs3darknet_cfg(), + cs3darknet_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33), -def create_stem( - in_chans=3, - out_chs=32, - kernel_size=3, - stride=2, - pool='', - padding='', - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - aa_layer=None -): - stem = nn.Sequential() - if not isinstance(out_chs, (tuple, list)): - out_chs = [out_chs] - assert len(out_chs) - in_c = in_chans - for i, out_c in enumerate(out_chs): - conv_name = f'conv{i + 1}' - stem.add_module(conv_name, ConvNormAct( - in_c, out_c, kernel_size, - stride=stride if i == 0 else 1, - padding=padding if i == 0 else '', - act_layer=act_layer, - norm_layer=norm_layer - )) - in_c = out_c - last_conv = conv_name - if pool: - if aa_layer is not None: - stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) - stem.add_module('aa', aa_layer(channels=in_c, stride=2)) - else: - stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) - return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv])) + cs3darknet_focus_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True), + cs3darknet_focus_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True), + cs3darknet_focus_l=_cs3darknet_cfg(focus=True), + cs3darknet_focus_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True), + + cs3sedarknet_xdw=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), + stages=CspStagesCfg( + depth=(3, 6, 12, 4), + out_chs=(256, 512, 1024, 2048), + stride=2, + groups=(1, 1, 256, 512), + bottle_ratio=0.5, + block_ratio=0.5, + attn_layer='se', + ), + ), +) -class ResBottleneck(nn.Module): +class BottleneckBlock(nn.Module): """ ResNe(X)t Bottleneck Block """ @@ -286,9 +349,9 @@ class ResBottleneck(nn.Module): attn_layer=None, aa_layer=None, drop_block=None, - drop_path=None + drop_path=0. ): - super(ResBottleneck, self).__init__() + super(BottleneckBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) @@ -299,8 +362,8 @@ class ResBottleneck(nn.Module): self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None - self.drop_path = drop_path - self.act3 = act_layer() + self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() + self.act3 = create_act_layer(act_layer) def zero_init_last(self): nn.init.zeros_(self.conv3.bn.weight) @@ -314,9 +377,7 @@ class ResBottleneck(nn.Module): x = self.conv3(x) if self.attn3 is not None: x = self.attn3(x) - if self.drop_path is not None: - x = self.drop_path(x) - x = x + shortcut + x = self.drop_path(x) + shortcut # FIXME partial shortcut needed if first block handled as per original, not used for my current impl #x[:, :shortcut.size(1)] += shortcut x = self.act3(x) @@ -339,7 +400,7 @@ class DarkBlock(nn.Module): attn_layer=None, aa_layer=None, drop_block=None, - drop_path=None + drop_path=0. ): super(DarkBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) @@ -349,7 +410,7 @@ class DarkBlock(nn.Module): mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, aa_layer=aa_layer, drop_layer=drop_block, **ckwargs) self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer) - self.drop_path = drop_path + self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() def zero_init_last(self): nn.init.zeros_(self.conv2.bn.weight) @@ -360,9 +421,7 @@ class DarkBlock(nn.Module): x = self.conv2(x) if self.attn is not None: x = self.attn(x) - if self.drop_path is not None: - x = self.drop_path(x) - x = x + shortcut + x = self.drop_path(x) + shortcut return x @@ -377,27 +436,27 @@ class CrossStage(nn.Module): depth, block_ratio=1., bottle_ratio=1., - exp_ratio=1., + expand_ratio=1., groups=1, first_dilation=None, avg_down=False, down_growth=False, cross_linear=False, block_dpr=None, - block_fn=ResBottleneck, + block_fn=BottleneckBlock, **block_kwargs ): super(CrossStage, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels - self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) + self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) block_out_chs = int(round(out_chs * block_ratio)) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: if avg_down: self.conv_down = nn.Sequential( - nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: @@ -417,9 +476,15 @@ class CrossStage(nn.Module): self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs # transition convs @@ -429,7 +494,7 @@ class CrossStage(nn.Module): def forward(self, x): x = self.conv_down(x) x = self.conv_exp(x) - xs, xb = x.split(self.exp_chs // 2, dim=1) + xs, xb = x.split(self.expand_chs // 2, dim=1) xb = self.blocks(xb) xb = self.conv_transition_b(xb).contiguous() out = self.conv_transition(torch.cat([xs, xb], dim=1)) @@ -449,27 +514,27 @@ class CrossStage3(nn.Module): depth, block_ratio=1., bottle_ratio=1., - exp_ratio=1., + expand_ratio=1., groups=1, first_dilation=None, avg_down=False, down_growth=False, cross_linear=False, block_dpr=None, - block_fn=ResBottleneck, + block_fn=BottleneckBlock, **block_kwargs ): super(CrossStage3, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels - self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) + self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) block_out_chs = int(round(out_chs * block_ratio)) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: if avg_down: self.conv_down = nn.Sequential( - nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: @@ -487,9 +552,15 @@ class CrossStage3(nn.Module): self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs # transition convs @@ -498,7 +569,7 @@ class CrossStage3(nn.Module): def forward(self, x): x = self.conv_down(x) x = self.conv_exp(x) - x1, x2 = x.split(self.exp_chs // 2, dim=1) + x1, x2 = x.split(self.expand_chs // 2, dim=1) x1 = self.blocks(x1) out = self.conv_transition(torch.cat([x1, x2], dim=1)) return out @@ -519,7 +590,7 @@ class DarkStage(nn.Module): groups=1, first_dilation=None, avg_down=False, - block_fn=ResBottleneck, + block_fn=BottleneckBlock, block_dpr=None, **block_kwargs ): @@ -529,7 +600,7 @@ class DarkStage(nn.Module): if avg_down: self.conv_down = nn.Sequential( - nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: @@ -541,9 +612,15 @@ class DarkStage(nn.Module): block_out_chs = int(round(out_chs * block_ratio)) self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs def forward(self, x): @@ -552,38 +629,131 @@ class DarkStage(nn.Module): return x -def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): - # get per stage args for stage and containing blocks, calculate strides to meet target output_stride - num_stages = len(cfg['depth']) - if 'groups' not in cfg: - cfg['groups'] = (1,) * num_stages - if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)): - cfg['down_growth'] = (cfg['down_growth'],) * num_stages - if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): - cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages - if 'avg_down' in cfg and not isinstance(cfg['avg_down'], (list, tuple)): - cfg['avg_down'] = (cfg['avg_down'],) * num_stages - cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ - [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] - stage_strides = [] - stage_dilations = [] - stage_first_dilations = [] +def create_csp_stem( + in_chans=3, + out_chs=32, + kernel_size=3, + stride=2, + pool='', + padding='', + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + aa_layer=None +): + stem = nn.Sequential() + feature_info = [] + if not isinstance(out_chs, (tuple, list)): + out_chs = [out_chs] + stem_depth = len(out_chs) + assert stem_depth + assert stride in (1, 2, 4) + prev_feat = None + prev_chs = in_chans + last_idx = stem_depth - 1 + stem_stride = 1 + for i, chs in enumerate(out_chs): + conv_name = f'conv{i + 1}' + conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1 + if conv_stride > 1 and prev_feat is not None: + feature_info.append(prev_feat) + stem.add_module(conv_name, ConvNormAct( + prev_chs, chs, kernel_size, + stride=conv_stride, + padding=padding if i == 0 else '', + act_layer=act_layer, + norm_layer=norm_layer + )) + stem_stride *= conv_stride + prev_chs = chs + prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name])) + if pool: + assert stride > 2 + if prev_feat is not None: + feature_info.append(prev_feat) + if aa_layer is not None: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) + stem.add_module('aa', aa_layer(channels=prev_chs, stride=2)) + pool_name = 'aa' + else: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + pool_name = 'pool' + stem_stride *= 2 + prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name])) + feature_info.append(prev_feat) + return stem, feature_info + + +def _get_stage_fn(stage_type: str, stage_args): + assert stage_type in ('dark', 'csp', 'cs3') + if stage_type == 'dark': + stage_args.pop('expand_ratio', None) + stage_args.pop('cross_linear', None) + stage_args.pop('down_growth', None) + stage_fn = DarkStage + elif stage_type == 'csp': + stage_fn = CrossStage + else: + stage_fn = CrossStage3 + return stage_fn, stage_args + + +def _get_block_fn(stage_type: str, stage_args): + assert stage_type in ('dark', 'bottle') + if stage_type == 'dark': + return DarkBlock, stage_args + else: + return BottleneckBlock, stage_args + + +def create_csp_stages( + cfg: CspModelCfg, + drop_path_rate: float, + output_stride: int, + stem_feat: Dict[str, Any] +): + cfg_dict = asdict(cfg.stages) + num_stages = len(cfg.stages.depth) + cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \ + [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)] + stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())] + block_kwargs = dict( + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + aa_layer=cfg.aa_layer + ) + dilation = 1 - for cfg_stride in cfg['stride']: - stage_first_dilations.append(dilation) - if curr_stride >= output_stride: - dilation *= cfg_stride + net_stride = stem_feat['reduction'] + prev_chs = stem_feat['num_chs'] + prev_feat = stem_feat + feature_info = [] + stages = [] + for stage_idx, stage_args in enumerate(stage_args): + stage_fn, stage_args = _get_stage_fn(stage_args.pop('stage_type'), stage_args) + block_fn, stage_args = _get_block_fn(stage_args.pop('block_type'), stage_args) + stride = stage_args.pop('stride') + if stride != 1 and prev_feat: + feature_info.append(prev_feat) + if net_stride >= output_stride and stride > 1: + dilation *= stride stride = 1 - else: - stride = cfg_stride - curr_stride *= stride - stage_strides.append(stride) - stage_dilations.append(dilation) - cfg['stride'] = stage_strides - cfg['dilation'] = stage_dilations - cfg['first_dilation'] = stage_first_dilations - stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())] - return stage_args + net_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + + stages += [stage_fn( + prev_chs, + **stage_args, + stride=stride, + first_dilation=first_dilation, + dilation=dilation, + block_fn=block_fn, + **block_kwargs, + )] + prev_chs = stage_args['out_chs'] + prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}') + + feature_info.append(prev_feat) + return nn.Sequential(*stages), feature_info class CspNet(nn.Module): @@ -598,43 +768,39 @@ class CspNet(nn.Module): def __init__( self, - cfg, + cfg: CspModelCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', - act_layer=nn.LeakyReLU, - norm_layer=nn.BatchNorm2d, - aa_layer=None, drop_rate=0., drop_path_rate=0., - zero_init_last=True, - stage_fn=CrossStage, - block_fn=ResBottleneck): + zero_init_last=True + ): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) - layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + layer_args = dict( + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + aa_layer=cfg.aa_layer + ) + self.feature_info = [] # Construct the stem - self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args) - self.feature_info = [stem_feat_info] - prev_chs = stem_feat_info['num_chs'] - curr_stride = stem_feat_info['reduction'] # reduction does not include pool - if cfg['stem']['pool']: - curr_stride *= 2 + self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args) + self.feature_info.extend(stem_feat_info[:-1]) # Construct the stages - per_stage_args = _cfg_to_stage_args( - cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate) - self.stages = nn.Sequential() - for i, sa in enumerate(per_stage_args): - self.stages.add_module( - str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn)) - prev_chs = sa['out_chs'] - curr_stride *= sa['stride'] - self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages, stage_feat_info = create_csp_stages( + cfg, + drop_path_rate=drop_path_rate, + output_stride=output_stride, + stem_feat=stem_feat_info[-1], + ) + prev_chs = stage_feat_info[-1]['num_chs'] + self.feature_info.extend(stage_feat_info) # Construct the head self.num_features = prev_chs @@ -729,54 +895,74 @@ def cspresnext50(pretrained=False, **kwargs): @register_model def cspdarknet53(pretrained=False, **kwargs): - return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) + return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs) @register_model def darknet17(pretrained=False, **kwargs): - return _create_cspnet('darknet17', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet17', pretrained=pretrained, **kwargs) @register_model def darknet21(pretrained=False, **kwargs): - return _create_cspnet('darknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet21', pretrained=pretrained, **kwargs) @register_model def sedarknet21(pretrained=False, **kwargs): - return _create_cspnet('sedarknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs) @register_model def darknet53(pretrained=False, **kwargs): - return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet53', pretrained=pretrained, **kwargs) @register_model def darknetaa53(pretrained=False, **kwargs): - return _create_cspnet( - 'darknetaa53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_s(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs) @register_model def cs3darknet_m(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) + return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs) @register_model def cs3darknet_l(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) + return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_x(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_s(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs) @register_model def cs3darknet_focus_m(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_focus_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) + return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs) @register_model def cs3darknet_focus_l(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_focus_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) \ No newline at end of file + return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_x(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs) + + +@register_model +def cs3sedarknet_xdw(pretrained=False, **kwargs): + return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs) From 11060f84c51d813d900ddf3a7b178b3e4fe87fb3 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 7 Jul 2022 14:44:55 -0700 Subject: [PATCH 31/45] make train.py compatible with torchrun --- train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train.py b/train.py index 285981fd..e5d40566 100755 --- a/train.py +++ b/train.py @@ -355,6 +355,8 @@ def main(): args.world_size = 1 args.rank = 0 # global rank if args.distributed: + if 'LOCAL_RANK' in os.environ: + args.local_rank = int(os.getenv('LOCAL_RANK')) args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') From 28e01520434930b420aedfc979f0f2fcee513628 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 15:13:06 -0700 Subject: [PATCH 32/45] Add --no-retry flag to benchmark.py to skip batch_size decay and retry on error. Fix #1226. Update deepspeed profile usage for latest DS releases. Fix # 1333 --- benchmark.py | 53 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/benchmark.py b/benchmark.py index 1362eeab..74f09489 100755 --- a/benchmark.py +++ b/benchmark.py @@ -71,6 +71,8 @@ parser.add_argument('--bench', default='both', type=str, help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'") parser.add_argument('--detail', action='store_true', default=False, help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False') +parser.add_argument('--no-retry', action='store_true', default=False, + help='Do not decay batch size and retry on error.') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for validation results (summary)') parser.add_argument('--num-warm-iter', default=10, type=int, @@ -169,10 +171,9 @@ def resolve_precision(precision: str): def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False): - macs, _ = get_model_profile( + _, macs, _ = get_model_profile( model=model, - input_res=(batch_size,) + input_size, # input shape or input to the input_constructor - input_constructor=None, # if specified, a constructor taking input_res is used as input to the model + input_shape=(batch_size,) + input_size, # input shape/resolution print_profile=detailed, # prints the model graph with the measured profile attached to each module detailed=detailed, # print the detailed profile warm_up=10, # the number of warm-ups before measuring the time of each module @@ -197,8 +198,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False class BenchmarkRunner: def __init__( - self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32', - fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): + self, + model_name, + detail=False, + device='cuda', + torchscript=False, + aot_autograd=False, + precision='float32', + fuser='', + num_warm_iter=10, + num_bench_iter=50, + use_train_size=False, + **kwargs + ): self.model_name = model_name self.detail = detail self.device = device @@ -256,7 +268,13 @@ class BenchmarkRunner: class InferenceBenchmarkRunner(BenchmarkRunner): - def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): + def __init__( + self, + model_name, + device='cuda', + torchscript=False, + **kwargs + ): super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) self.model.eval() @@ -325,7 +343,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner): class TrainBenchmarkRunner(BenchmarkRunner): - def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): + def __init__( + self, + model_name, + device='cuda', + torchscript=False, + **kwargs + ): super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) self.model.train() @@ -492,7 +516,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16): return max(0, int(out_batch_size)) -def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): +def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False): batch_size = initial_batch_size results = dict() error_str = 'Unknown' @@ -507,8 +531,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): if 'channels_last' in error_str: _logger.error(f'{model_name} not supported in channels_last, skipping.') break - _logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.') + _logger.error(f'"{error_str}" while running benchmark.') + if no_batch_size_retry: + break batch_size = decay_batch_exp(batch_size) + _logger.warning(f'Reducing batch size to {batch_size} for retry.') results['error'] = error_str return results @@ -550,7 +577,13 @@ def benchmark(args): model_results = OrderedDict(model=model) for prefix, bench_fn in zip(prefixes, bench_fns): - run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs) + run_results = _try_run( + model, + bench_fn, + bench_kwargs=bench_kwargs, + initial_batch_size=batch_size, + no_batch_size_retry=args.no_retry, + ) if prefix and 'error' not in run_results: run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()} model_results.update(run_results) From 500c190860bb80da348dd719dd8f0b73e44f0854 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 15:15:25 -0700 Subject: [PATCH 33/45] Add --aot-autograd (functorch efficient mem fusion) support to validate.py --- validate.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/validate.py b/validate.py index 27b88299..708ac2e5 100755 --- a/validate.py +++ b/validate.py @@ -38,6 +38,12 @@ try: except AttributeError: pass +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -101,8 +107,11 @@ parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') -parser.add_argument('--torchscript', dest='torchscript', action='store_true', - help='convert model torchscript for inference') +scripting_group = parser.add_mutually_exclusive_group() +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', + help='torch.jit.script the full model') +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', @@ -162,7 +171,10 @@ def validate(args): if args.torchscript: torch.jit.optimized_execution(True) - model = torch.jit.script(model) + model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size'])) + if args.aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + model = memory_efficient_fusion(model) model = model.cuda() if args.apex_amp: From 4670d375c6ac37457094fa5519079936328d1e67 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 15:21:29 -0700 Subject: [PATCH 34/45] Reorg benchmark.py import --- benchmark.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/benchmark.py b/benchmark.py index 74f09489..23047bb5 100755 --- a/benchmark.py +++ b/benchmark.py @@ -6,24 +6,23 @@ An inference and train step benchmark script for timm models. Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse -import os import csv import json -import time import logging -import torch -import torch.nn as nn -import torch.nn.parallel +import time from collections import OrderedDict from contextlib import suppress from functools import partial +import torch +import torch.nn as nn +import torch.nn.parallel + +from timm.data import resolve_data_config from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 -from timm.data import resolve_data_config from timm.utils import setup_default_logging, set_jit_fuser - has_apex = False try: from apex import amp From 9be0c847154b7a20cfd3f51fc0c366099c007f5c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 15:33:53 -0700 Subject: [PATCH 35/45] Change set -> dict w/ None keys for dataset split synonym search, so always consistent if more than 1 exists. Fix #1224 --- timm/data/dataset_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index 194a597e..d0ac30b1 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -26,8 +26,8 @@ _TORCH_BASIC_DS = dict( kmnist=KMNIST, fashion_mnist=FashionMNIST, ) -_TRAIN_SYNONYM = {'train', 'training'} -_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'} +_TRAIN_SYNONYM = dict(train=None, training=None) +_EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None) def _search_split(root, split): From 58621723bda1fe386e8eebd729e743e255f992eb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 17:43:38 -0700 Subject: [PATCH 36/45] Add CrossStage3 DarkNet (cs3) weights --- timm/models/cspnet.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index f0a26baf..e8e8910e 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -57,30 +57,35 @@ default_cfgs = { 'sedarknet21': _cfg(url=''), 'darknet53': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic' + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0, ), 'darknetaa53': _cfg(url=''), 'cs3darknet_s': _cfg( - url=''), + url='', interpolation='bicubic'), 'cs3darknet_m': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95, + ), 'cs3darknet_l': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95), 'cs3darknet_x': _cfg( url=''), 'cs3darknet_focus_s': _cfg( - url=''), + url='', interpolation='bicubic'), 'cs3darknet_focus_m': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95), 'cs3darknet_focus_l': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95), 'cs3darknet_focus_x': _cfg( - url=''), + url='', interpolation='bicubic'), 'cs3sedarknet_xdw': _cfg( - url=''), + url='', interpolation='bicubic'), } From ce65a7b29fa39c3f4d09b03b515b2af31ec9aea5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 21:33:25 -0700 Subject: [PATCH 37/45] Update vit_relpos w/ some additional weights, some cleanup to match recent vit updates, more MLP log coord experiments. --- timm/models/vision_transformer_relpos.py | 194 +++++++++++++++++------ 1 file changed, 145 insertions(+), 49 deletions(-) diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 0c9ac989..52b3ce45 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -8,6 +8,7 @@ import math import logging from functools import partial from collections import OrderedDict +from dataclasses import dataclass from typing import Optional, Tuple import torch @@ -16,7 +17,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, named_apply +from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple from .registry import register_model @@ -47,9 +48,16 @@ default_cfgs = { 'vit_relpos_base_patch16_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'), + 'vit_srelpos_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'), + 'vit_srelpos_medium_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'), + + 'vit_relpos_medium_patch16_cls_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'), 'vit_relpos_base_patch16_cls_224': _cfg( url=''), - 'vit_relpos_base_patch16_gapcls_224': _cfg( + 'vit_relpos_base_patch16_clsgap_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'), 'vit_relpos_small_patch16_rpn_224': _cfg(url=''), @@ -59,35 +67,43 @@ default_cfgs = { } -def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor: - # cut and paste w/ modifications from swin / beit codebase - # cls to token & token 2 cls & cls to cls +def gen_relative_position_index( + q_size: Tuple[int, int], + k_size: Tuple[int, int] = None, + class_token: bool = False) -> torch.Tensor: + # Adapted with significant modifications from Swin / BeiT codebases # get pair-wise relative position index for each token inside the window - window_area = win_size[0] * win_size[1] - coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww - relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += win_size[1] - 1 - relative_coords[:, :, 0] *= 2 * win_size[1] - 1 + q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww + if k_size is None: + k_coords = q_coords + k_size = q_size + else: + # different q vs k sizes is a WIP + k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1) + relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2 + _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0) + if class_token: - num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3 - relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) - relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias + # NOTE not intended or tested with MLP log-coords + max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1])) + num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3 + relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0]) relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 relative_position_index[0, 0] = num_relative_distance - 1 - else: - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - return relative_position_index + + return relative_position_index.contiguous() def gen_relative_log_coords( win_size: Tuple[int, int], pretrained_win_size: Tuple[int, int] = (0, 0), - mode='swin' + mode='swin', ): - # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well + assert mode in ('swin', 'cr', 'rw') + # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) @@ -100,12 +116,22 @@ def gen_relative_log_coords( relative_coords_table[:, :, 0] /= (win_size[0] - 1) relative_coords_table[:, :, 1] /= (win_size[1] - 1) relative_coords_table *= 8 # normalize to -8, 8 - scale = math.log2(8) + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) / math.log2(8) else: - # FIXME we should support a form of normalization (to -1/1) for this mode? - scale = math.log2(math.e) - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - 1.0 + relative_coords_table.abs()) / scale + if mode == 'rw': + # cr w/ window size normalization -> [-1,1] log coords + relative_coords_table[:, :, 0] /= (win_size[0] - 1) + relative_coords_table[:, :, 1] /= (win_size[1] - 1) + relative_coords_table *= 8 # scale to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) + relative_coords_table /= math.log2(9) # -> [-1, 1] + else: + # mode == 'cr' + relative_coords_table = torch.sign(relative_coords_table) * torch.log( + 1.0 + relative_coords_table.abs()) + return relative_coords_table @@ -115,19 +141,29 @@ class RelPosMlp(nn.Module): window_size, num_heads=8, hidden_dim=128, - class_token=False, + prefix_tokens=0, mode='cr', pretrained_window_size=(0, 0) ): super().__init__() self.window_size = window_size self.window_area = self.window_size[0] * self.window_size[1] - self.class_token = 1 if class_token else 0 + self.prefix_tokens = prefix_tokens self.num_heads = num_heads self.bias_shape = (self.window_area,) * 2 + (num_heads,) - self.apply_sigmoid = mode == 'swin' + if mode == 'swin': + self.bias_act = nn.Sigmoid() + self.bias_gain = 16 + mlp_bias = (True, False) + elif mode == 'rw': + self.bias_act = nn.Tanh() + self.bias_gain = 4 + mlp_bias = True + else: + self.bias_act = nn.Identity() + self.bias_gain = None + mlp_bias = True - mlp_bias = (True, False) if mode == 'swin' else True self.mlp = Mlp( 2, # x, y hidden_features=hidden_dim, @@ -155,10 +191,11 @@ class RelPosMlp(nn.Module): self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.view(self.bias_shape) relative_position_bias = relative_position_bias.permute(2, 0, 1) - if self.apply_sigmoid: - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - if self.class_token: - relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0]) + relative_position_bias = self.bias_act(relative_position_bias) + if self.bias_gain is not None: + relative_position_bias = self.bias_gain * relative_position_bias + if self.prefix_tokens: + relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0]) return relative_position_bias.unsqueeze(0).contiguous() def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): @@ -167,18 +204,18 @@ class RelPosMlp(nn.Module): class RelPosBias(nn.Module): - def __init__(self, window_size, num_heads, class_token=False): + def __init__(self, window_size, num_heads, prefix_tokens=0): super().__init__() + assert prefix_tokens <= 1 self.window_size = window_size self.window_area = window_size[0] * window_size[1] - self.class_token = 1 if class_token else 0 - self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,) + self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,) - num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) self.register_buffer( "relative_position_index", - gen_relative_position_index(self.window_size, class_token=self.class_token), + gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0), persistent=False, ) @@ -306,11 +343,32 @@ class VisionTransformerRelPos(nn.Module): """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6, - class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', - embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='avg', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=1e-6, + class_token=False, + fc_norm=False, + rel_pos_type='mlp', + rel_pos_dim=None, + shared_rel_pos=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='skip', + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + block_fn=RelPosBlock + ): """ Args: img_size (int, tuple): input image size @@ -345,19 +403,22 @@ class VisionTransformerRelPos(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 if class_token else 0 + self.num_prefix_tokens = 1 if class_token else 0 self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) feat_size = self.patch_embed.grid_size - rel_pos_args = dict(window_size=feat_size, class_token=class_token) + rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens) if rel_pos_type.startswith('mlp'): if rel_pos_dim: rel_pos_args['hidden_dim'] = rel_pos_dim + # FIXME experimenting with different relpos log coord configs if 'swin' in rel_pos_type: rel_pos_args['mode'] = 'swin' + elif 'rw' in rel_pos_type: + rel_pos_args['mode'] = 'rw' rel_pos_cls = partial(RelPosMlp, **rel_pos_args) else: rel_pos_cls = partial(RelPosBias, **rel_pos_args) @@ -367,7 +428,7 @@ class VisionTransformerRelPos(nn.Module): # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both... rel_pos_cls = None - self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None + self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ @@ -434,7 +495,7 @@ class VisionTransformerRelPos(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: - x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) return x if pre_logits else self.head(x) @@ -502,6 +563,41 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs): return model +@register_model +def vit_srelpos_small_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ shared relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=False, + rel_pos_dim=384, shared_rel_pos=True, **kwargs) + model = _create_vision_transformer_relpos('vit_srelpos_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_srelpos_medium_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ shared relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False, + rel_pos_dim=512, shared_rel_pos=True, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_srelpos_medium_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_medium_patch16_cls_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-M/16) w/ relative log-coord position, class token present + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False, + rel_pos_dim=256, class_token=True, global_pool='token', **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_medium_patch16_cls_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present @@ -514,14 +610,14 @@ def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs): @register_model -def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs): +def vit_relpos_base_patch16_clsgap_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled Leaving here for comparisons w/ a future re-train as it performs quite well. """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs) - model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_clsgap_224', pretrained=pretrained, **model_kwargs) return model From 7c7ecd24923b19338ca083d56369193e153294f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 22:01:24 -0700 Subject: [PATCH 38/45] Add --use-train-size flag to force use of train input_size (over test input size) for validation. Default test-time pooling to use train input size (fixes issues). --- timm/models/layers/test_time_pool.py | 2 +- validate.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py index 98c0bf53..5826d8c9 100644 --- a/timm/models/layers/test_time_pool.py +++ b/timm/models/layers/test_time_pool.py @@ -36,7 +36,7 @@ class TestTimePoolHead(nn.Module): return x.view(x.size(0), -1) -def apply_test_time_pool(model, config, use_test_size=True): +def apply_test_time_pool(model, config, use_test_size=False): test_time_pool = False if not hasattr(model, 'default_cfg') or not model.default_cfg: return model, False diff --git a/validate.py b/validate.py index 708ac2e5..7fa22b49 100755 --- a/validate.py +++ b/validate.py @@ -67,6 +67,8 @@ parser.add_argument('--img-size', default=None, type=int, metavar='N', help='Input image dimension, uses model default if empty') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') +parser.add_argument('--use-train-size', action='store_true', default=False, + help='force use of train input size, even when test size is specified in pretrained cfg') parser.add_argument('--crop-pct', default=None, type=float, metavar='N', help='Input image center crop pct') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', @@ -164,10 +166,15 @@ def validate(args): param_count = sum([m.numel() for m in model.parameters()]) _logger.info('Model %s created, param count: %d' % (args.model, param_count)) - data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True) + data_config = resolve_data_config( + vars(args), + model=model, + use_test_size=not args.use_train_size, + verbose=True + ) test_time_pool = False if args.test_pool: - model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) + model, test_time_pool = apply_test_time_pool(model, data_config) if args.torchscript: torch.jit.optimized_execution(True) From a1cb25066e26c8bf8fa410987b808899657c9e20 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 22:02:57 -0700 Subject: [PATCH 39/45] Add edgnext_small_rw weights trained with swin like recipe. Better than original 'small' but not the recent 'USI' distilled weights. --- timm/models/edgenext.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 97971ba6..29316b9a 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -46,10 +46,13 @@ default_cfgs = dict( # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"), edgenext_small=_cfg( # USI weights url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", - crop_pct=0.95 + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, ), - edgenext_small_rw=_cfg(), + edgenext_small_rw=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth', + test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), ) From 1c5cb819f94834b73843e1f088a3b2b12f550680 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 22:05:56 -0700 Subject: [PATCH 40/45] bump version to 0.6.3 before merge --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index 3e8e43bd..7165c7fa 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.6.2.dev0' +__version__ = '0.6.3.dev0' From a8e34051c1de050421abd50fbc1201d125a50fe7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 23:07:43 -0700 Subject: [PATCH 41/45] Unbreak gamma remap impacting beit checkpoint load, version bump to 0.6.4 --- timm/models/deit.py | 4 +++- timm/models/vision_transformer.py | 4 ++-- timm/version.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/timm/models/deit.py b/timm/models/deit.py index a2f43b91..8cb36bd6 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -10,6 +10,8 @@ Modifications copyright 2021, Ross Wightman """ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. +from functools import partial + import torch from torch import nn as nn @@ -177,7 +179,7 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs): model_cls = VisionTransformerDistilled if distilled else VisionTransformer model = build_model_with_cfg( model_cls, variant, pretrained, - pretrained_filter_fn=checkpoint_filter_fn, + pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True), **kwargs) return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 022052d0..c92c22a3 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -626,7 +626,7 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): return posemb -def checkpoint_filter_fn(state_dict, model): +def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): """ convert patch embedding weight from manual patchify + linear proj to conv""" import re out_dict = {} @@ -647,7 +647,7 @@ def checkpoint_filter_fn(state_dict, model): getattr(model, 'num_prefix_tokens', 1), model.patch_embed.grid_size ) - elif 'gamma_' in k: + elif adapt_layer_scale and 'gamma_' in k: # remap layer-scale gamma into sub-module (deit3 models) k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) elif 'pre_logits' in k: diff --git a/timm/version.py b/timm/version.py index 7165c7fa..02f8497c 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.6.3.dev0' +__version__ = '0.6.4' From a45b4bce9a022d413eb27a342a7a9997580bb4aa Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 8 Jul 2022 10:53:27 -0700 Subject: [PATCH 42/45] x and xx small edgenext models do benefit from larger test input size --- timm/models/edgenext.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 29316b9a..a81bd9b0 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -39,9 +39,11 @@ def _cfg(url='', **kwargs): default_cfgs = dict( edgenext_xx_small=_cfg( - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth"), + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), edgenext_x_small=_cfg( - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth"), + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), # edgenext_small=_cfg( # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"), edgenext_small=_cfg( # USI weights From 66393d472fc0d5c428635a91032255b36c064499 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 8 Jul 2022 12:21:23 -0700 Subject: [PATCH 43/45] Update README.md --- README.md | 44 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 4c39c692..9d6c9394 100644 --- a/README.md +++ b/README.md @@ -13,16 +13,50 @@ ## Sponsors -A big thank you to my [GitHub Sponsors](https://github.com/sponsors/rwightman) for their support! - -In addition to the sponsors at the link above, I've received hardware and/or cloud resources from +Thanks to the following for hardware support: +* TPU Research Cloud (TRC) (https://sites.research.google/trc/about/) * Nvidia (https://www.nvidia.com/en-us/) -* TFRC (https://www.tensorflow.org/tfrc) -I'm fortunate to be able to dedicate significant time and money of my own supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of cloud services, hardware, and electricity costs. +And a big thanks to all GitHub sponsors who helped with some of my costs before I joined Hugging Face. ## What's New +### July 8, 2022 +More models, more fixes +* Official research models (w/ weights) added: + * EdgeNeXt from (https://github.com/mmaaz60/EdgeNeXt) + * MobileViT-V2 from (https://github.com/apple/ml-cvnets) + * DeiT III (Revenge of the ViT) from (https://github.com/facebookresearch/deit) +* My own models: + * Small `ResNet` defs added by request with 1 block repeats for both basic and bottleneck (resnet10 and resnet14) + * `CspNet` refactored with dataclass config, simplified CrossStage3 (`cs3`) option. These are closer to YOLO-v5+ backbone defs. + * More relative position vit fiddling. Two `srelpos` (shared relative position) models trained, and a medium w/ class token. + * Add an alternate downsample mode to EdgeNeXt and train a `small` model. Better than original small, but not their new USI trained weights. +* My own model weight results (all ImageNet-1k training) + * `resnet10t` - 66.5 @ 176, 68.3 @ 224 + * `resnet14t` - 71.3 @ 176, 72.3 @ 224 + * `resnetaa50` - 80.6 @ 224 , 81.6 @ 288 + * `darknet53` - 80.0 @ 256, 80.5 @ 288 + * `cs3darknet_m` - 77.0 @ 256, 77.6 @ 288 + * `cs3darknet_focus_m` - 76.7 @ 256, 77.3 @ 288 + * `cs3darknet_l` - 80.4 @ 256, 80.9 @ 288 + * `cs3darknet_focus_l` - 80.3 @ 256, 80.9 @ 288 + * `vit_srelpos_small_patch16_224` - 81.1 @ 224, 82.1 @ 320 + * `vit_srelpos_medium_patch16_224` - 82.3 @ 224, 83.1 @ 320 + * `vit_relpos_small_patch16_cls_224` - 82.6 @ 224, 83.6 @ 320 + * `edgnext_small_rw` - 79.6 @ 224, 80.4 @ 320 +* `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs. +* Hugging Face Hub support fixes verified, demo notebook TBA +* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation. +* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103) +* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases. + * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges. + * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py` +* Numerous bug fixes +* Currently testing for imminent PyPi 0.6.x release +* LeViT pretraining of larger models still a WIP, they don't train well / easily without distillation. Time to add distill support (finally)? +* ImageNet-22k weight training + finetune ongoing, work on multi-weight support (slowly) chugging along (there are a LOT of weights, sigh) ... + ### May 13, 2022 * Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript. * Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects. From 2898cf6e41357b6e79229547ad191ca74299f5d2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 10 Jul 2022 16:43:23 -0700 Subject: [PATCH 44/45] version 0.6.5 for pypi release --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index 02f8497c..e2f45ae2 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.6.4' +__version__ = '0.6.5' From 4e7ffe50435c3a40a074a0e2bb17663f89948ff3 Mon Sep 17 00:00:00 2001 From: Muhammad Maaz Date: Tue, 12 Jul 2022 05:08:34 +0400 Subject: [PATCH 45/45] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 9d6c9394..e4c058f1 100644 --- a/README.md +++ b/README.md @@ -383,6 +383,7 @@ A full version of the list below with source links can be found in the [document * DenseNet - https://arxiv.org/abs/1608.06993 * DLA - https://arxiv.org/abs/1707.06484 * DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629 +* EdgeNeXt - https://arxiv.org/abs/2206.10589 * EfficientNet (MBConvNet Family) * EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252 * EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665