From 8e11da0ce38e96b17760f8aa048c01f7c75ece3b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 Sep 2021 15:42:57 -0700 Subject: [PATCH 01/20] Add experimental RegNetZ(ish) models for training / perf trials. --- timm/models/byobnet.py | 77 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index edce355a..50ad1e88 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -137,6 +137,17 @@ default_cfgs = { 'gcresnext50ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth', first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + + # experimental models + 'regnetz_b': _cfg( + url='', + input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + 'regnetz_c': _cfg( + url='', + input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + 'regnetz_d': _cfg( + url='', + input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), } @@ -489,6 +500,51 @@ model_cfgs = dict( act_layer='silu', attn_layer='gca', ), + + # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW + regnetz_b=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=192, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=6, c=384, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), + ), + stem_chs=32, + stem_pool='', + num_features=1792, + act_layer='silu', + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.25), + ), + regnetz_c=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=128, s=2, gs=16, br=0.5, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), + ), + stem_chs=32, + stem_pool='', + num_features=1792, + act_layer='silu', + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.25), + ), + regnetz_d=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), + ), + stem_chs=128, + stem_type='quad', + stem_pool='', + num_features=1792, + act_layer='silu', + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.25), + ), ) @@ -678,6 +734,27 @@ def gcresnext50ts(pretrained=False, **kwargs): return _create_byobnet('gcresnext50ts', pretrained=pretrained, **kwargs) +@register_model +def regnetz_b(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('regnetz_b', pretrained=pretrained, **kwargs) + + +@register_model +def regnetz_c(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('regnetz_c', pretrained=pretrained, **kwargs) + + +@register_model +def regnetz_d(pretrained=False, **kwargs): + """ + """ + return 
_create_byobnet('regnetz_d', pretrained=pretrained, **kwargs) + + def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]: if not isinstance(stage_blocks_cfg, Sequence): stage_blocks_cfg = (stage_blocks_cfg,) From da06cc61d4081925dea57864e73926aab405cfaa Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 Sep 2021 15:43:22 -0700 Subject: [PATCH 02/20] ResNetV2 seems to work best without zero_init residual --- timm/models/resnetv2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 2ff4da8c..2b5121a2 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -344,7 +344,7 @@ class ResNetV2(nn.Module): num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), - drop_rate=0., drop_path_rate=0., zero_init_last=True): + drop_rate=0., drop_path_rate=0., zero_init_last=False): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate From 515121cca1545a3a8ac3c077579f970bcfce00da Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 Sep 2021 15:43:48 -0700 Subject: [PATCH 03/20] Use reshape instead of view in std_conv, causing issues in recent PyTorch in channels_last --- timm/models/layers/std_conv.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py index 3ccc16e1..d896ba5c 100644 --- a/timm/models/layers/std_conv.py +++ b/timm/models/layers/std_conv.py @@ -41,7 +41,7 @@ class StdConv2d(nn.Conv2d): def forward(self, x): weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0., eps=self.eps).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -67,7 +67,7 @@ class StdConv2dSame(nn.Conv2d): if self.same_pad: x = pad_same(x, self.kernel_size, self.stride, self.dilation) weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0., eps=self.eps).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -96,7 +96,7 @@ class ScaledStdConv2d(nn.Conv2d): def forward(self, x): weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, weight=(self.gain * self.scale).view(-1), training=True, momentum=0., eps=self.eps).reshape_as(self.weight) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) @@ -127,7 +127,7 @@ class ScaledStdConv2dSame(nn.Conv2d): if self.same_pad: x = pad_same(x, self.kernel_size, self.stride, self.dilation) weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, weight=(self.gain * self.scale).view(-1), training=True, momentum=0., eps=self.eps).reshape_as(self.weight) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) From f8a63a3b7173edcf1086ab0a3c4d9226bfc40569 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 Sep 2021 15:44:38 -0700 Subject: [PATCH 04/20] 
Add worker_init_fn to loader for numpy seed per worker --- timm/data/loader.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/timm/data/loader.py b/timm/data/loader.py index 99cf132f..7d5aa1e5 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -125,6 +125,12 @@ class PrefetchLoader: self.loader.collate_fn.mixup_enabled = x +def _worker_init(worker_id): + worker_info = torch.utils.data.get_worker_info() + assert worker_info.id == worker_id + np.random.seed(worker_info.seed % (2**32-1)) + + def create_loader( dataset, input_size, @@ -202,7 +208,6 @@ def create_loader( collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate loader_class = torch.utils.data.DataLoader - if use_multi_epochs_loader: loader_class = MultiEpochsDataLoader @@ -214,6 +219,7 @@ def create_loader( collate_fn=collate_fn, pin_memory=pin_memory, drop_last=is_training, + worker_init_fn=_worker_init, persistent_workers=persistent_workers) try: loader = loader_class(dataset, **loader_args) From 5d6983c4622ece75e7b9f7f05e1114be8f54deb9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 Sep 2021 15:45:17 -0700 Subject: [PATCH 05/20] Batch validate a list of files if model is a text file with model per line --- validate.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/validate.py b/validate.py index ab5b644f..9b2c0f7e 100755 --- a/validate.py +++ b/validate.py @@ -296,6 +296,11 @@ def main(): model_names = list_models(args.model) model_cfgs = [(n, '') for n in model_names] + if not model_cfgs and os.path.isfile(args.model): + with open(args.model) as f: + model_names = [line.rstrip() for line in f] + model_cfgs = [(n, None) for n in model_names if n] + if len(model_cfgs): results_file = args.results_file or './results-all.csv' _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) From 0387e6057e191899f20ca2bf48aa2039ece910cc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 Sep 2021 15:45:39 -0700 Subject: [PATCH 06/20] Update binary cross ent impl to use thresholding as an option (convert soft targets from mixup/cutmix to 0, 1) --- timm/loss/__init__.py | 2 +- timm/loss/binary_cross_entropy.py | 50 +++++++++++++++++++++++-------- timm/loss/cross_entropy.py | 16 +++++----- train.py | 8 +++-- 4 files changed, 51 insertions(+), 25 deletions(-) diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py index a74bcb88..ea7f15f2 100644 --- a/timm/loss/__init__.py +++ b/timm/loss/__init__.py @@ -1,4 +1,4 @@ from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel -from .binary_cross_entropy import DenseBinaryCrossEntropy +from .binary_cross_entropy import BinaryCrossEntropy from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from .jsd import JsdCrossEntropy diff --git a/timm/loss/binary_cross_entropy.py b/timm/loss/binary_cross_entropy.py index 6da04dba..ed76c1e8 100644 --- a/timm/loss/binary_cross_entropy.py +++ b/timm/loss/binary_cross_entropy.py @@ -1,23 +1,47 @@ +""" Binary Cross Entropy w/ a few extras + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -class DenseBinaryCrossEntropy(nn.Module): - """ BCE using one-hot from dense targets w/ label smoothing +class BinaryCrossEntropy(nn.Module): + """ BCE with optional one-hot from dense targets, label smoothing, thresholding NOTE for experiments comparing CE to BCE /w label smoothing, 
may remove """ - def __init__(self, smoothing=0.1): - super(DenseBinaryCrossEntropy, self).__init__() + def __init__( + self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None, + reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None): + super(BinaryCrossEntropy, self).__init__() assert 0. <= smoothing < 1.0 self.smoothing = smoothing - self.bce = nn.BCEWithLogitsLoss() + self.target_threshold = target_threshold + self.reduction = reduction + self.register_buffer('weight', weight) + self.register_buffer('pos_weight', pos_weight) - def forward(self, x, target): - num_classes = x.shape[-1] - off_value = self.smoothing / num_classes - on_value = 1. - self.smoothing + off_value - target = target.long().view(-1, 1) - target = torch.full( - (target.size()[0], num_classes), off_value, device=x.device, dtype=x.dtype).scatter_(1, target, on_value) - return self.bce(x, target) + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + assert x.shape[0] == target.shape[0] + if target.shape != x.shape: + # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse + num_classes = x.shape[-1] + # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ + off_value = self.smoothing / num_classes + on_value = 1. - self.smoothing + off_value + target = target.long().view(-1, 1) + target = torch.full( + (target.size()[0], num_classes), + off_value, + device=x.device, dtype=x.dtype).scatter_(1, target, on_value) + if self.target_threshold is not None: + # Make target 0, or 1 if threshold set + target = target.gt(self.target_threshold).to(dtype=target.dtype) + return F.binary_cross_entropy_with_logits( + x, target, + self.weight, + pos_weight=self.pos_weight, + reduction=self.reduction) diff --git a/timm/loss/cross_entropy.py b/timm/loss/cross_entropy.py index 60bef646..85198107 100644 --- a/timm/loss/cross_entropy.py +++ b/timm/loss/cross_entropy.py @@ -1,23 +1,23 @@ +""" Cross Entropy w/ smoothing or soft targets + +Hacked together by / Copyright 2021 Ross Wightman +""" + import torch import torch.nn as nn import torch.nn.functional as F class LabelSmoothingCrossEntropy(nn.Module): - """ - NLL loss with label smoothing. + """ NLL loss with label smoothing. """ def __init__(self, smoothing=0.1): - """ - Constructor for the LabelSmoothing module. - :param smoothing: label smoothing factor - """ super(LabelSmoothingCrossEntropy, self).__init__() assert smoothing < 1.0 self.smoothing = smoothing self.confidence = 1. - smoothing - def forward(self, x, target): + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: logprobs = F.log_softmax(x, dim=-1) nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) nll_loss = nll_loss.squeeze(1) @@ -31,6 +31,6 @@ class SoftTargetCrossEntropy(nn.Module): def __init__(self): super(SoftTargetCrossEntropy, self).__init__() - def forward(self, x, target): + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) return loss.mean() diff --git a/train.py b/train.py index 3943c7d0..55aba416 100755 --- a/train.py +++ b/train.py @@ -190,6 +190,8 @@ parser.add_argument('--jsd-loss', action='store_true', default=False, help='Enable Jensen-Shannon Divergence + CE loss. 
Use with `--aug-splits`.') parser.add_argument('--bce-loss', action='store_true', default=False, help='Enable BCE loss w/ Mixup/CutMix use.') +parser.add_argument('--bce-target-thresh', type=float, default=None, + help='Threshold for binarizing softened BCE targets (default: None, disabled)') parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='pixel', @@ -459,7 +461,7 @@ def main(): else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") - model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 + model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.dist_bn) # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch @@ -558,12 +560,12 @@ def main(): elif mixup_active: # smoothing is handled with mixup target transform which outputs sparse, soft targets if args.bce_loss: - train_loss_fn = nn.BCEWithLogitsLoss() + train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh) else: train_loss_fn = SoftTargetCrossEntropy() elif args.smoothing: if args.bce_loss: - train_loss_fn = DenseBinaryCrossEntropy(smoothing=args.smoothing) + train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh) else: train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: From 6478bcd02c8af1b238d6ecb7c2eba247e91b1246 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 26 Sep 2021 14:54:17 -0700 Subject: [PATCH 07/20] Fix regnetz_d conv layer name, use inception mean/std --- timm/models/byobnet.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 50ad1e88..d00aeb32 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -141,13 +141,16 @@ default_cfgs = { # experimental models 'regnetz_b': _cfg( url='', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), 'regnetz_c': _cfg( url='', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), 'regnetz_d': _cfg( url='', - input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), } From 80075b0b8a4372c6be82ff4dc88517572d91b94a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 28 Sep 2021 16:37:45 -0700 Subject: [PATCH 08/20] Add worker_seeding arg to allow selecting old vs updated data loader worker seed for (old) experiment repeatability --- timm/data/loader.py | 25 ++++++++++++++++++++----- train.py | 5 ++++- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/timm/data/loader.py b/timm/data/loader.py index 7d5aa1e5..a02399a3 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -3,8 +3,11 @@ Prefetcher and Fast Collate inspired by NVIDIA APEX example at https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ +import random +from functools import partial +from typing import Callable import torch.utils.data import numpy as np @@ -125,10 +128,20 @@ class PrefetchLoader: 
self.loader.collate_fn.mixup_enabled = x -def _worker_init(worker_id): +def _worker_init(worker_id, worker_seeding='all'): worker_info = torch.utils.data.get_worker_info() assert worker_info.id == worker_id - np.random.seed(worker_info.seed % (2**32-1)) + if isinstance(worker_seeding, Callable): + seed = worker_seeding(worker_info) + random.seed(seed) + torch.manual_seed(seed) + np.random.seed(seed % (2 ** 32 - 1)) + else: + assert worker_seeding in ('all', 'part') + # random / torch seed already called in dataloader iter class w/ worker_info.seed + # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed) + if worker_seeding == 'all': + np.random.seed(worker_info.seed % (2 ** 32 - 1)) def create_loader( @@ -162,6 +175,7 @@ def create_loader( tf_preprocessing=False, use_multi_epochs_loader=False, persistent_workers=True, + worker_seeding='all', ): re_num_splits = 0 if re_split: @@ -219,8 +233,9 @@ def create_loader( collate_fn=collate_fn, pin_memory=pin_memory, drop_last=is_training, - worker_init_fn=_worker_init, - persistent_workers=persistent_workers) + worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding), + persistent_workers=persistent_workers + ) try: loader = loader_class(dataset, **loader_args) except TypeError as e: diff --git a/train.py b/train.py index 55aba416..d95611ad 100755 --- a/train.py +++ b/train.py @@ -252,6 +252,8 @@ parser.add_argument('--model-ema-decay', type=float, default=0.9998, # Misc parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') +parser.add_argument('--worker-seeding', type=str, default='all', + help='worker seed mode (default: all)') parser.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', @@ -535,7 +537,8 @@ def main(): distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, - use_multi_epochs_loader=args.use_multi_epochs_loader + use_multi_epochs_loader=args.use_multi_epochs_loader, + worker_seeding=args.worker_seeding, ) loader_eval = create_loader( From b81e79aae9579a6868e139780ad064d73693c2d3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 28 Sep 2021 16:38:41 -0700 Subject: [PATCH 09/20] Fix bottleneck attn transpose typo, hopefully these train better now.. 
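The attention output here has shape (B, num_heads, H * W, dim_head), so it is the last two dims that need swapping, putting heads and head channels ahead of the spatial dim before the reshape to (B, dim_out, H, W). A minimal shape sketch of the difference (illustrative only, sizes picked arbitrarily):

import torch

B, num_heads, dim_head, H, W = 2, 4, 16, 8, 8
dim_out = num_heads * dim_head
attn_out = torch.randn(B, num_heads, H * W, dim_head)  # result of attn @ v

fixed = attn_out.transpose(-1, -2).reshape(B, dim_out, H, W)  # channels = heads x dim_head, spatial preserved
old = attn_out.transpose(1, 2).reshape(B, dim_out, H, W)      # same shape, but channel/spatial values scrambled

print(fixed.shape, old.shape)  # torch.Size([2, 64, 8, 8]) twice, only the first has the intended layout
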
--- timm/models/layers/bottleneck_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index c0c619cc..bf6af675 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -122,7 +122,7 @@ class BottleneckAttn(nn.Module): attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W attn_out = attn_logits.softmax(dim=-1) - attn_out = (attn_out @ v).transpose(1, 2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W + attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W attn_out = self.pool(attn_out) return attn_out From 0ca687f224a8634a65cd580c5d5bbe86ef16bbdb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 30 Sep 2021 21:49:38 -0700 Subject: [PATCH 10/20] Make 'regnetz' model experiments closer to actual RegNetZ, bottleneck expansion, expand from in_chs, no shortcut on stride 2, tweak model sizes --- timm/models/byoanet.py | 52 ++++++++- timm/models/byobnet.py | 257 +++++++++++++++++++---------------------- 2 files changed, 171 insertions(+), 138 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 6558de35..5c7be0d6 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -57,8 +57,14 @@ default_cfgs = { input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'lambda_resnet26t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_256-b040fce6.pth', + url='', + min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + 'lambda_resnet50ts': _cfg( + url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + 'lambda_resnet26rt_256': _cfg( + url='', + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), } @@ -198,6 +204,33 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict(r=9) ), + lambda_resnet50ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + act_layer='silu', + self_attn_layer='lambda', + self_attn_kwargs=dict(r=9) + ), + lambda_resnet26rt_256=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + self_attn_layer='lambda', + self_attn_kwargs=dict(r=None) + ), ) @@ -275,6 +308,21 @@ def eca_halonext26ts(pretrained=False, **kwargs): @register_model def lambda_resnet26t(pretrained=False, **kwargs): - """ Lambda-ResNet-26T. Lambda layers in last two stages. + """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages. """ return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs) + + +@register_model +def lambda_resnet50ts(pretrained=False, **kwargs): + """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages. 
+ """ + return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs) + + +@register_model +def lambda_resnet26rt_256(pretrained=False, **kwargs): + """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages. + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('lambda_resnet26rt_256', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index d00aeb32..515f2073 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -51,6 +51,16 @@ def _cfg(url='', **kwargs): } +def _cfgr(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc', + **kwargs + } + + default_cfgs = { # GPU-Efficient (ResNet) weights 'gernet_s': _cfg( @@ -92,65 +102,50 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth', first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0), - 'resnet61q': _cfg( + 'resnet61q': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet61q_ra2-6afc536c.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), - test_input_size=(3, 288, 288), crop_pct=1.0, interpolation='bicubic'), - - 'resnext26ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'gcresnext26ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'seresnext26ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'eca_resnext26ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'bat_resnext26ts': _cfg( + test_input_size=(3, 288, 288), crop_pct=1.0), + + 'resnext26ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth'), + 'gcresnext26ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth'), + 'seresnext26ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth'), + 'eca_resnext26ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth'), + 'bat_resnext26ts': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/bat_resnext26ts_256-fa6fd595.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic', min_input_size=(3, 256, 256)), - 'resnet32ts': _cfg( - 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'resnet33ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'gcresnet33ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'seresnet33ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'eca_resnet33ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - - 'gcresnet50t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - - 'gcresnext50ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - - # experimental models - 'regnetz_b': _cfg( + 'resnet32ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth'), + 'resnet33ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth'), + 'gcresnet33ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth'), + 'seresnet33ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth'), + 'eca_resnet33ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth'), + + 'gcresnet50t': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth'), + + 'gcresnext50ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth'), + + # experimental models, likely to change ot be removed + 'regnetz_b': _cfgr( url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'regnetz_c': _cfg( + input_size=(3, 224, 224), pool_size=(7, 7), first_conv='stem.conv'), + 'regnetz_c': _cfgr( url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'regnetz_d': _cfg( + imean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), first_conv='stem.conv'), + 'regnetz_d': _cfgr( url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + 
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), } @@ -507,46 +502,52 @@ model_cfgs = dict( # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW regnetz_b=ByoModelCfg( blocks=( - ByoBlockCfg(type='bottle', d=2, c=192, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=6, c=384, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3), ), stem_chs=32, stem_pool='', - num_features=1792, + downsample='', + num_features=1536, act_layer='silu', attn_layer='se', attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), ), regnetz_c=ByoModelCfg( blocks=( - ByoBlockCfg(type='bottle', d=2, c=128, s=2, gs=16, br=0.5, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4), + ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4), + ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4), + ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4), ), stem_chs=32, stem_pool='', - num_features=1792, + downsample='', + num_features=1536, act_layer='silu', attn_layer='se', attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), ), regnetz_d=ByoModelCfg( blocks=( - ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4), + ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4), + ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=32, br=4), + ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=32, br=4), ), - stem_chs=128, - stem_type='quad', + stem_chs=64, + stem_type='tiered', stem_pool='', + downsample='', num_features=1792, act_layer='silu', attn_layer='se', attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), ), ) @@ -802,11 +803,17 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) -def create_downsample(downsample_type, layers: LayerFn, **kwargs): - if downsample_type == 'avg': - return DownsampleAvg(**kwargs) +def create_shortcut(downsample_type, layers: LayerFn, in_chs, out_chs, stride, dilation, **kwargs): + assert downsample_type in ('avg', 'conv1x1', '') + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + if not downsample_type: + return None # no shortcut + elif downsample_type == 'avg': + return DownsampleAvg(in_chs, out_chs, stride=stride, dilation=dilation[0], **kwargs) + else: + return layers.conv_norm_act(in_chs, out_chs, 
kernel_size=1, stride=stride, dilation=dilation[0], **kwargs) else: - return layers.conv_norm_act(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs) + return nn.Identity() # identity shortcut class BasicBlock(nn.Module): @@ -822,12 +829,9 @@ class BasicBlock(nn.Module): mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0]) self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) @@ -838,23 +842,21 @@ class BasicBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv2_kxk.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - - # residual path + shortcut = x x = self.conv1_kxk(x) x = self.conv2_kxk(x) x = self.attn(x) x = self.drop_path(x) - - x = self.act(x + shortcut) - return x + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) class BottleneckBlock(nn.Module): @@ -862,24 +864,18 @@ class BottleneckBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', attn_last=False, linear_out=False, extra_conv=False, layers: LayerFn = None, - drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False, + layers: LayerFn = None, drop_block=None, drop_path_rate=0.): super(BottleneckBlock, self).__init__() layers = layers or LayerFn() - mid_chs = make_divisible(out_chs * bottle_ratio) + mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) - self.conv2_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block) self.conv2_kxk = layers.conv_norm_act( mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block) @@ -895,15 +891,14 @@ class BottleneckBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv3_1x1.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): 
attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - + shortcut = x x = self.conv1_1x1(x) x = self.conv2_kxk(x) x = self.conv2b_kxk(x) @@ -911,9 +906,9 @@ class BottleneckBlock(nn.Module): x = self.conv3_1x1(x) x = self.attn_last(x) x = self.drop_path(x) - - x = self.act(x + shortcut) - return x + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) class DarkBlock(nn.Module): @@ -935,12 +930,9 @@ class DarkBlock(nn.Module): mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) @@ -952,22 +944,22 @@ class DarkBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv2_kxk.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - + shortcut = x x = self.conv1_1x1(x) x = self.attn(x) x = self.conv2_kxk(x) x = self.attn_last(x) x = self.drop_path(x) - x = self.act(x + shortcut) - return x + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) class EdgeBlock(nn.Module): @@ -988,12 +980,9 @@ class EdgeBlock(nn.Module): mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_kxk = layers.conv_norm_act( in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], @@ -1005,22 +994,22 @@ class EdgeBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv2_1x1.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - + shortcut = x x = self.conv1_kxk(x) x = self.attn(x) x = self.conv2_1x1(x) x = self.attn_last(x) x = self.drop_path(x) - x = self.act(x + shortcut) - return x + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) class RepVggBlock(nn.Module): @@ -1065,8 +1054,7 @@ class RepVggBlock(nn.Module): x = self.drop_path(x) # not in the paper / official impl, experimental x = x + identity x = self.attn(x) # no attn in the paper / official impl, experimental - x = self.act(x) - return x + return self.act(x) class SelfAttnBlock(nn.Module): @@ -1074,19 +1062,16 @@ class 
SelfAttnBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None, - layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True, + feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): super(SelfAttnBlock, self).__init__() assert layers is not None - mid_chs = make_divisible(out_chs * bottle_ratio) + mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) if extra_conv: @@ -1105,7 +1090,7 @@ class SelfAttnBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv3_1x1.bn.weight) if hasattr(self.self_attn, 'reset_parameters'): self.self_attn.reset_parameters() From d657e2cc0bda62327489b0246b4a795fa69570a6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 30 Sep 2021 21:54:42 -0700 Subject: [PATCH 11/20] Remove dead code line from efficientnet --- timm/models/efficientnet_blocks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index b43f38f5..b1fec449 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -277,7 +277,6 @@ class EdgeResidual(nn.Module): mid_chs = make_divisible(force_in_chs * exp_ratio) else: mid_chs = make_divisible(in_chs * exp_ratio) - has_se = se_layer is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_path_rate = drop_path_rate From b49630a1382507bc13abc16483a3bf533c4d23da Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 30 Sep 2021 22:45:09 -0700 Subject: [PATCH 12/20] Add relative pos embed option to LambdaLayer, fix last transpose/reshape. 
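When r is None the layer now uses a learned relative position table of shape (2 * H - 1, 2 * W - 1, dim_k), gathered per query/key pair via rel_pos_indices. A small standalone sketch of that indexing, assuming a 3x3 feature map (sizes are only for illustration):

import torch

def rel_pos_indices(size):
    # pairwise (dy, dx) offsets between all positions, shifted to be non-negative
    pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
    rel_pos = pos[:, None, :] - pos[:, :, None]
    rel_pos[0] += size[0] - 1
    rel_pos[1] += size[1] - 1
    return rel_pos  # 2, H * W, H * W

H = W = 3
dim_k = 16
idx = rel_pos_indices((H, W))
pos_emb = torch.zeros(2 * H - 1, 2 * W - 1, dim_k)  # learned parameter in the layer, zeros here
per_pair = pos_emb[idx[0], idx[1]]                  # H * W, H * W, dim_k lookup for every (query, key) pair

print(idx.shape, per_pair.shape)  # torch.Size([2, 9, 9]) torch.Size([9, 9, 16])
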
--- timm/models/byoanet.py | 8 ++--- timm/models/layers/lambda_layer.py | 49 ++++++++++++++++++++++++------ 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 5c7be0d6..056813ef 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -62,7 +62,7 @@ default_cfgs = { 'lambda_resnet50ts': _cfg( url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), - 'lambda_resnet26rt_256': _cfg( + 'lambda_resnet26rpt_256': _cfg( url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), } @@ -218,7 +218,7 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict(r=9) ), - lambda_resnet26rt_256=ByoModelCfg( + lambda_resnet26rpt_256=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), @@ -321,8 +321,8 @@ def lambda_resnet50ts(pretrained=False, **kwargs): @register_model -def lambda_resnet26rt_256(pretrained=False, **kwargs): +def lambda_resnet26rpt_256(pretrained=False, **kwargs): """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages. """ kwargs.setdefault('img_size', 256) - return _create_byoanet('lambda_resnet26rt_256', pretrained=pretrained, **kwargs) + return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs) diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index d298c1aa..fd174855 100644 --- a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -24,18 +24,30 @@ import torch from torch import nn import torch.nn.functional as F +from .helpers import to_2tuple from .weight_init import trunc_normal_ +def rel_pos_indices(size): + size = to_2tuple(size) + pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1) + rel_pos = pos[:, None, :] - pos[:, :, None] + rel_pos[0] += size[0] - 1 + rel_pos[1] += size[1] - 1 + return rel_pos # 2, H * W, H * W + + class LambdaLayer(nn.Module): - """Lambda Layer w/ lambda conv position embedding + """Lambda Layer Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` - https://arxiv.org/abs/2102.08602 + + NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. 
""" def __init__( self, - dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): + dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): super().__init__() self.dim = dim self.dim_out = dim_out or dim @@ -43,7 +55,6 @@ class LambdaLayer(nn.Module): self.num_heads = num_heads assert self.dim_out % num_heads == 0, ' should be divided by num_heads' self.dim_v = self.dim_out // num_heads # value depth 'v' - self.r = r # relative position neighbourhood (lambda conv kernel size) self.qkv = nn.Conv2d( dim, @@ -52,8 +63,19 @@ class LambdaLayer(nn.Module): self.norm_q = nn.BatchNorm2d(num_heads * dim_head) self.norm_v = nn.BatchNorm2d(self.dim_v) - # NOTE currently only supporting the local lambda convolutions for positional - self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + if r is not None: + # local lambda convolution for pos + self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + self.pos_emb = None + self.rel_pos = None + else: + # relative pos embedding + assert feat_size is not None + feat_size = to_2tuple(feat_size) + rel_size = [2 * s - 1 for s in feat_size] + self.conv_lambda = None + self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k)) + self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() @@ -61,12 +83,14 @@ class LambdaLayer(nn.Module): def reset_parameters(self): trunc_normal_(self.qkv.weight, std=self.dim ** -0.5) - trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) + if self.conv_lambda is not None: + trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) + if self.pos_emb is not None: + trunc_normal_(self.pos_emb, std=.02) def forward(self, x): B, C, H, W = x.shape M = H * W - qkv = self.qkv(x) q, k, v = torch.split(qkv, [ self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) @@ -77,10 +101,15 @@ class LambdaLayer(nn.Module): content_lam = k @ v # B, K, V content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V - position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K - position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + if self.pos_emb is None: + position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K + position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + else: + # FIXME relative pos embedding path not fully verified + pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) + position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V - out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W) # B, C (num_heads * V), H, W + out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W out = self.pool(out) return out From b1c2e3eb92c85e460ca90133b93bbbf5476f927e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 30 Sep 2021 23:19:05 -0700 Subject: [PATCH 13/20] Match rel_pos_indices attr rename in conv branch --- timm/models/layers/lambda_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index fd174855..eeb77e45 100644 --- 
a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -67,7 +67,7 @@ class LambdaLayer(nn.Module): # local lambda convolution for pos self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) self.pos_emb = None - self.rel_pos = None + self.rel_pos_indices = None else: # relative pos embedding assert feat_size is not None From d9abfa48df3090e6157fefa22e9ae05c28e62d07 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Oct 2021 13:43:55 -0700 Subject: [PATCH 14/20] Make broadcast_buffers disable its own flag for now (needs more testing on interaction with dist_bn) --- train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index d95611ad..785b99e2 100755 --- a/train.py +++ b/train.py @@ -270,6 +270,8 @@ parser.add_argument('--apex-amp', action='store_true', default=False, help='Use NVIDIA Apex AMP mixed precision') parser.add_argument('--native-amp', action='store_true', default=False, help='Use Native Torch AMP mixed precision') +parser.add_argument('--no-ddp-bb', action='store_true', default=False, + help='Force broadcast buffers for native DDP to off.') parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') parser.add_argument('--pin-mem', action='store_true', default=False, @@ -463,7 +465,7 @@ def main(): else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") - model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.dist_bn) + model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch From 007bc3932375a71e7c1e4aa0b7f0c79f3bb79f56 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Oct 2021 15:51:42 -0700 Subject: [PATCH 15/20] Some halo and bottleneck attn code cleanup, add halonet50ts weights, use optimal crop ratios --- timm/models/byoanet.py | 9 ++-- timm/models/layers/bottleneck_attn.py | 12 ++--- timm/models/layers/halo_attn.py | 68 +++++++++++++++------------ 3 files changed, 48 insertions(+), 41 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 056813ef..f58b724c 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -3,7 +3,7 @@ A flexible network w/ dataclass based config for stacking NN blocks including self-attention (or similar) layers. 
-Currently used to implement experimential variants of: +Currently used to implement experimental variants of: * Bottleneck Transformers * Lambda ResNets * HaloNets @@ -46,15 +46,16 @@ default_cfgs = { 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet26t': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth', - input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'sehalonet33ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'halonet50ts': _cfg( - url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_256_ra3-f07eab9f.pth', + input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'eca_halonext26ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth', - input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'lambda_resnet26t': _cfg( url='', diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index bf6af675..61859f9c 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -118,12 +118,12 @@ class BottleneckAttn(nn.Module): x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2) q, k, v = torch.split(x, self.num_heads, dim=1) - attn_logits = (q @ k.transpose(-1, -2)) * self.scale - attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W + attn = (q @ k.transpose(-1, -2)) * self.scale + attn = attn + self.pos_embed(q) # B, num_heads, H * W, H * W + attn = attn.softmax(dim=-1) - attn_out = attn_logits.softmax(dim=-1) - attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W - attn_out = self.pool(attn_out) - return attn_out + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W + out = self.pool(out) + return out diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index d298fc0b..034c66a8 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -106,22 +106,23 @@ class HaloAttn(nn.Module): assert dim_out % num_heads == 0 self.stride = stride self.num_heads = num_heads - self.dim_head = dim_head or dim // num_heads - self.dim_qk = num_heads * self.dim_head - self.dim_v = dim_out + self.dim_head_qk = dim_head or dim_out // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v self.block_size = block_size self.halo_size = halo_size self.win_size = block_size + halo_size * 2 # neighbourhood window size - self.scale = self.dim_head ** -0.5 + self.scale = self.dim_head_qk ** -0.5 # FIXME not clear if this stride behaviour is what the paper intended # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving # data in unfolded block form. 
I haven't wrapped my head around how that'd look. - self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias) - self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias) + self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias) + self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias) self.pos_embed = PosEmbedRel( - block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale) + block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) self.reset_parameters() @@ -143,37 +144,42 @@ class HaloAttn(nn.Module): q = self.q(x) # unfold - q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) + q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) # B, num_heads * dim_head * block_size ** 2, num_blocks - q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3) + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3) # B * num_heads, num_blocks, block_size ** 2, dim_head kv = self.kv(x) # generate overlapping windows for kv kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]) kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape( - B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1) - # NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity - # if self.stride_tricks: - # kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() - # kv = kv.as_strided(( - # B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), - # stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) - # else: - # kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) - # kv = kv.reshape( - # B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) - k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1) - # B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads - - attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied? 
- attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2 - - attn_out = attn_logits.softmax(dim=-1) - attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks + B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1) + k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1) + # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v + attn = (q @ k.transpose(-1, -2)) * self.scale + attn = attn + self.pos_embed(q) # B * num_heads, num_blocks, block_size ** 2, win_size ** 2 + attn = attn.softmax(dim=-1) + + out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks # fold - attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) - attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride) + out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) + out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride) # B, dim_out, H // stride, W // stride - return attn_out + return out + + +""" Two alternatives for overlapping windows. + +`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold() + + if self.stride_tricks: + kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() + kv = kv.as_strided(( + B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), + stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) + else: + kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + kv = kv.reshape( + B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3) +""" From b2094f4ee845d89aca8de65ae9b6ae09829a8b8e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:31:22 -0700 Subject: [PATCH 16/20] support bits checkpoints in avg/load --- avg_checkpoints.py | 4 ++++ timm/models/helpers.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/avg_checkpoints.py b/avg_checkpoints.py index 1f7604b0..ea8bbe84 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -41,6 +41,10 @@ def checkpoint_metric(checkpoint_path): metric = None if 'metric' in checkpoint: metric = checkpoint['metric'] + elif 'metrics' in checkpoint and 'metric_name' in checkpoint: + metrics = checkpoint['metrics'] + print(metrics) + metric = metrics[checkpoint['metric_name']] return metric diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 662a7a48..bd97cf20 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -24,13 +24,20 @@ _logger = logging.getLogger(__name__) def load_state_dict(checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') - state_dict_key = 'state_dict' + state_dict_key = '' if isinstance(checkpoint, dict): - if use_ema and 'state_dict_ema' in checkpoint: + if use_ema and checkpoint.get('state_dict_ema', None) is not None: state_dict_key = 'state_dict_ema' - if state_dict_key and state_dict_key in checkpoint: + elif use_ema and checkpoint.get('model_ema', None) is not None: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 
'model' + if state_dict_key: + state_dict = checkpoint[state_dict_key] new_state_dict = OrderedDict() - for k, v in checkpoint[state_dict_key].items(): + for k, v in state_dict.items(): # strip `module.` prefix name = k[7:] if k.startswith('module') else k new_state_dict[name] = v From 64495505b7bbf5438672d53804efaaa634bad710 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:31:39 -0700 Subject: [PATCH 17/20] Add updated lambda resnet26 and botnet26 checkpoints with fixes applied --- timm/models/byoanet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index f58b724c..3c43378a 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -34,8 +34,8 @@ def _cfg(url='', **kwargs): default_cfgs = { # GPU-Efficient (ResNet) weights 'botnet26t_256': _cfg( - url='', - fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth', + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95), 'botnet50ts_256': _cfg( url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), @@ -58,13 +58,13 @@ default_cfgs = { input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'lambda_resnet26t': _cfg( - url='', - min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_a2h_256-25ded63d.pth', + min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95), 'lambda_resnet50ts': _cfg( url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet26rpt_256': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), } From cc9bedf373209664854dc400cbe5801e3fc1e6e9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:32:02 -0700 Subject: [PATCH 18/20] Add initial ResNet Strikes Back weights for ResNet50 and ResNetV2-50 models --- timm/models/resnet.py | 4 ++-- timm/models/resnetv2.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index dad42f38..1f0716c5 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -53,11 +53,11 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth', interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), 'resnet50': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-00ca2c6a.pth', interpolation='bicubic'), 'resnet50d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', - interpolation='bicubic', first_conv='conv1.0'), + interpolation='bicubic', first_conv='conv1.0', crop_pct=0.95), 'resnet50t': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 2b5121a2..fe7fc466 100644 --- a/timm/models/resnetv2.py +++ 
b/timm/models/resnetv2.py @@ -105,7 +105,8 @@ default_cfgs = { input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'), 'resnetv2_50': _cfg( - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1_h-000cdf49.pth', + interpolation='bicubic', crop_pct=0.95), 'resnetv2_50d': _cfg( interpolation='bicubic', first_conv='stem.conv1'), 'resnetv2_50t': _cfg( From da0d39bedd873c17d6cd2af50d78cbed564019c7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:33:16 -0700 Subject: [PATCH 19/20] Update default crop_pct for byoanet --- timm/models/byoanet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 3c43378a..61f94490 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -23,7 +23,7 @@ __all__ = [] def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'crop_pct': 0.95, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc', 'fixed_input_size': False, 'min_input_size': (3, 224, 224), @@ -35,7 +35,7 @@ default_cfgs = { # GPU-Efficient (ResNet) weights 'botnet26t_256': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth', - fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95), + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'botnet50ts_256': _cfg( url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), @@ -59,7 +59,7 @@ default_cfgs = { 'lambda_resnet26t': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_a2h_256-25ded63d.pth', - min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95), + min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet50ts': _cfg( url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), From 93901e992f7bcb6bdb46729a307f67e39dd9b5fd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:34:57 -0700 Subject: [PATCH 20/20] Version bump to 0.5.0 for pending release post RSB and ATTN updates --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index 779b9fc3..2b8877c5 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.13' +__version__ = '0.5.0'
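
For readers following the helpers.py hunk in PATCH 16, below is a minimal standalone sketch of the checkpoint-key resolution order it introduces: EMA weights are preferred when requested, and both the classic 'state_dict'/'state_dict_ema' keys and the 'model'/'model_ema' keys written by bits checkpoints are recognised. The helper name resolve_state_dict and the placeholder checkpoint path are illustrative assumptions, not timm API; the actual entry point remains load_state_dict() in timm/models/helpers.py as patched above.

from collections import OrderedDict


def resolve_state_dict(checkpoint, use_ema=False):
    # Pick the state dict out of a loaded checkpoint dict, mirroring the key
    # preference order added in PATCH 16 (sketch only, not the timm code itself).
    state_dict_key = ''
    if isinstance(checkpoint, dict):
        if use_ema and checkpoint.get('state_dict_ema') is not None:
            state_dict_key = 'state_dict_ema'
        elif use_ema and checkpoint.get('model_ema') is not None:
            state_dict_key = 'model_ema'
        elif 'state_dict' in checkpoint:
            state_dict_key = 'state_dict'
        elif 'model' in checkpoint:
            state_dict_key = 'model'
    state_dict = checkpoint[state_dict_key] if state_dict_key else checkpoint
    # strip any 'module.' prefix left by DataParallel / DistributedDataParallel wrappers
    return OrderedDict((k[7:] if k.startswith('module') else k, v) for k, v in state_dict.items())


# usage sketch ('my_checkpoint.pth.tar' is a placeholder path):
#   checkpoint = torch.load('my_checkpoint.pth.tar', map_location='cpu')
#   model.load_state_dict(resolve_state_dict(checkpoint, use_ema=True))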