From 34f382f8f6583a80cb0a169c275bf0806d95ca06 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 14:50:36 -0700 Subject: [PATCH 01/26] move dataconfig before script, scripting killing metadata now (PyTorch 1.12? just nvfuser?) --- benchmark.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmark.py b/benchmark.py index f348fcb9..1362eeab 100755 --- a/benchmark.py +++ b/benchmark.py @@ -225,11 +225,12 @@ class BenchmarkRunner: self.num_classes = self.model.num_classes self.param_count = count_params(self.model) _logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) + + data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.scripted = False if torchscript: self.model = torch.jit.script(self.model) self.scripted = True - data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] self.batch_size = kwargs.pop('batch_size', 256) From a050fde5cde892404a5b77973a5916cdd7b602ab Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 15:03:28 -0700 Subject: [PATCH 02/26] Add resnet10t (basic block) and resnet14t (bottleneck) with 1,1,1,1 repeats --- timm/models/resnet.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index a7f0c0f6..476ffe91 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -723,6 +723,24 @@ def _create_resnet(variant, pretrained=False, **kwargs): return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) +@register_model +def resnet10t(pretrained=False, **kwargs): + """Constructs a ResNet-10-T model. + """ + model_args = dict( + block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet10t', pretrained, **model_args) + + +@register_model +def resnet14t(pretrained=False, **kwargs): + """Constructs a ResNet-14-T model. + """ + model_args = dict( + block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet14t', pretrained, **model_args) + + @register_model def resnet18(pretrained=False, **kwargs): """Constructs a ResNet-18 model. From 82c311d0821643e8613b0b18f6d0f14088a79459 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 15:14:01 -0700 Subject: [PATCH 03/26] Add more experimental darknet and 'cs2' darknet variants (different cross stage setup, closer to newer YOLO backbones) for train trials. 
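
The 'cs2' cross stage added below differs from the existing CrossStage in using a
single transition conv over the concat output (CrossStage also keeps a per-branch
conv_transition_b). A minimal standalone sketch of that flow, using a hypothetical
TinyCrossStage2 with plain nn.Conv2d in place of the timm ConvNormAct helpers:

import torch
import torch.nn as nn

class TinyCrossStage2(nn.Module):
    # hypothetical illustration only, not the CrossStage2 module added below
    def __init__(self, in_chs, out_chs, depth=2):
        super().__init__()
        self.conv_down = nn.Conv2d(in_chs, out_chs, 3, stride=2, padding=1)  # downsample
        self.conv_exp = nn.Conv2d(out_chs, out_chs, 1)  # expansion (exp_ratio=1 for simplicity)
        half = out_chs // 2
        self.blocks = nn.Sequential(*[
            nn.Sequential(nn.Conv2d(half, half, 3, padding=1), nn.SiLU())
            for _ in range(depth)])
        self.conv_transition = nn.Conv2d(out_chs, out_chs, 1)  # one conv over the concat

    def forward(self, x):
        x = self.conv_exp(self.conv_down(x))
        x1, x2 = x.chunk(2, dim=1)  # split expanded channels into two branches
        x1 = self.blocks(x1)        # only one branch runs the blocks
        return self.conv_transition(torch.cat([x1, x2], dim=1))

# TinyCrossStage2(64, 128)(torch.randn(1, 64, 32, 32)).shape -> torch.Size([1, 128, 16, 16])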
--- timm/models/cspnet.py | 384 ++++++++++++++++++++++++++---- timm/models/layers/conv_bn_act.py | 19 +- 2 files changed, 352 insertions(+), 51 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index f8a87fab..095e4701 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -16,6 +16,7 @@ from functools import partial import torch import torch.nn as nn +import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP @@ -46,11 +47,21 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth', input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.875 # FIXME I trained this at 224x224, not 256 like ref impl ), - 'cspresnext50_iabn': _cfg(url=''), 'cspdarknet53': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'), - 'cspdarknet53_iabn': _cfg(url=''), + + 'darknet17': _cfg(url=''), + 'darknet21': _cfg(url=''), 'darknet53': _cfg(url=''), + + 'cs2darknet_m': _cfg( + url=''), + 'cs2darknet_l': _cfg( + url=''), + 'cs2darknet_f_m': _cfg( + url=''), + 'cs2darknet_f_l': _cfg( + url=''), } @@ -116,6 +127,37 @@ model_cfgs = dict( down_growth=True, ) ), + darknet17=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1,) * 5, + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + ) + ), + darknet21=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 1, 1, 2, 2), + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + ) + ), + sedarknet21=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 1, 1, 2, 2), + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + attn_layer=('se',) * 5, + ) + ), darknet53=dict( stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), stage=dict( @@ -125,13 +167,81 @@ model_cfgs = dict( bottle_ratio=(0.5,) * 5, block_ratio=(1.,) * 5, ) + ), + + darknetaa53=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 2, 8, 8, 4), + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + avg_down=True, + ), + ), + + cs2darknet_m=dict( + stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''), + stage=dict( + out_chs=(96, 192, 384, 768), + depth=(2, 4, 6, 2), + stride=(2,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + avg_down=False, + ), + ), + + cs2darknet_f_m=dict( + stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), + stage=dict( + out_chs=(96, 192, 384, 768), + depth=(2, 4, 6, 2), + stride=(2,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + avg_down=False, + ), + ), + + cs2darknet_l=dict( + stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), + stage=dict( + out_chs=(128, 256, 512, 1024), + depth=(3, 6, 9, 3), + stride=(2,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + avg_down=False, + ), + ), + + cs2darknet_f_l=dict( + stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''), + stage=dict( + out_chs=(128, 256, 512, 1024), + depth=(3, 6, 9, 3), + stride=(2,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + avg_down=False, + ), ) ) def 
create_stem( - in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='', - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): + in_chans=3, + out_chs=32, + kernel_size=3, + stride=2, + pool='', + padding='', + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + aa_layer=None +): stem = nn.Sequential() if not isinstance(out_chs, (tuple, list)): out_chs = [out_chs] @@ -140,8 +250,12 @@ def create_stem( for i, out_c in enumerate(out_chs): conv_name = f'conv{i + 1}' stem.add_module(conv_name, ConvNormAct( - in_c, out_c, kernel_size, stride=stride if i == 0 else 1, - act_layer=act_layer, norm_layer=norm_layer)) + in_c, out_c, kernel_size, + stride=stride if i == 0 else 1, + padding=padding if i == 0 else '', + act_layer=act_layer, + norm_layer=norm_layer + )) in_c = out_c last_conv = conv_name if pool: @@ -158,9 +272,20 @@ class ResBottleneck(nn.Module): """ def __init__( - self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False, - attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + self, + in_chs, + out_chs, + dilation=1, + bottle_ratio=0.25, + groups=1, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + attn_last=False, + attn_layer=None, + aa_layer=None, + drop_block=None, + drop_path=None + ): super(ResBottleneck, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) @@ -173,7 +298,7 @@ class ResBottleneck(nn.Module): self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None self.drop_path = drop_path - self.act3 = act_layer(inplace=True) + self.act3 = act_layer() def zero_init_last(self): nn.init.zeros_(self.conv3.bn.weight) @@ -201,9 +326,19 @@ class DarkBlock(nn.Module): """ def __init__( - self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, - drop_block=None, drop_path=None): + self, + in_chs, + out_chs, + dilation=1, + bottle_ratio=0.5, + groups=1, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + attn_layer=None, + aa_layer=None, + drop_block=None, + drop_path=None + ): super(DarkBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) @@ -211,7 +346,7 @@ class DarkBlock(nn.Module): self.conv2 = ConvNormActAa( mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, aa_layer=aa_layer, drop_layer=drop_block, **ckwargs) - self.attn = create_attn(attn_layer, channels=out_chs) + self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer) self.drop_path = drop_path def zero_init_last(self): @@ -232,23 +367,44 @@ class DarkBlock(nn.Module): class CrossStage(nn.Module): """Cross Stage.""" def __init__( - self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1., - groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None, - block_fn=ResBottleneck, **block_kwargs): + self, + in_chs, + out_chs, + stride, + dilation, + depth, + block_ratio=1., + bottle_ratio=1., + exp_ratio=1., + groups=1, + first_dilation=None, + avg_down=False, + down_growth=False, + cross_linear=False, + block_dpr=None, + block_fn=ResBottleneck, + **block_kwargs + ): super(CrossStage, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample 
channels to output channels - exp_chs = int(round(out_chs * exp_ratio)) + self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) block_out_chs = int(round(out_chs * block_ratio)) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: - self.conv_down = ConvNormActAa( - in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, - aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) + if avg_down: + self.conv_down = nn.Sequential( + nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ) + else: + self.conv_down = ConvNormActAa( + in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) prev_chs = down_chs else: - self.conv_down = None + self.conv_down = nn.Identity() prev_chs = in_chs # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also, @@ -269,30 +425,115 @@ class CrossStage(nn.Module): self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) def forward(self, x): - if self.conv_down is not None: - x = self.conv_down(x) + x = self.conv_down(x) x = self.conv_exp(x) - split = x.shape[1] // 2 - xs, xb = x[:, :split], x[:, split:] + xs, xb = x.split(self.exp_chs // 2, dim=1) xb = self.blocks(xb) xb = self.conv_transition_b(xb).contiguous() out = self.conv_transition(torch.cat([xs, xb], dim=1)) return out +class CrossStage2(nn.Module): + """Cross Stage v2. + Similar to CrossStage, but with one transition conv for the concat output. + """ + def __init__( + self, + in_chs, + out_chs, + stride, + dilation, + depth, + block_ratio=1., + bottle_ratio=1., + exp_ratio=1., + groups=1, + first_dilation=None, + avg_down=False, + down_growth=False, + cross_linear=False, + block_dpr=None, + block_fn=ResBottleneck, + **block_kwargs + ): + super(CrossStage2, self).__init__() + first_dilation = first_dilation or dilation + down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels + self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) + block_out_chs = int(round(out_chs * block_ratio)) + conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) + + if stride != 1 or first_dilation != dilation: + if avg_down: + self.conv_down = nn.Sequential( + nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ) + else: + self.conv_down = ConvNormActAa( + in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) + prev_chs = down_chs + else: + self.conv_down = None + prev_chs = in_chs + + # expansion conv + self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs) + prev_chs = exp_chs // 2 # expanded output is split in 2 for blocks and cross stage + + self.blocks = nn.Sequential() + for i in range(depth): + drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None + self.blocks.add_module(str(i), block_fn( + prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + prev_chs = block_out_chs + + # transition convs + 
self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) + + def forward(self, x): + x = self.conv_down(x) + x = self.conv_exp(x) + x1, x2 = x.split(self.exp_chs // 2, dim=1) + x1 = self.blocks(x1) + out = self.conv_transition(torch.cat([x1, x2], dim=1)) + return out + + class DarkStage(nn.Module): """DarkNet stage.""" def __init__( - self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1, - first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs): + self, + in_chs, + out_chs, + stride, + dilation, + depth, + block_ratio=1., + bottle_ratio=1., + groups=1, + first_dilation=None, + avg_down=False, + block_fn=ResBottleneck, + block_dpr=None, + **block_kwargs + ): super(DarkStage, self).__init__() first_dilation = first_dilation or dilation + conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) - self.conv_down = ConvNormActAa( - in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, - act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'), - aa_layer=block_kwargs.get('aa_layer', None)) + if avg_down: + self.conv_down = nn.Sequential( + nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ) + else: + self.conv_down = ConvNormActAa( + in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) prev_chs = out_chs block_out_chs = int(round(out_chs * block_ratio)) @@ -318,6 +559,8 @@ def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): cfg['down_growth'] = (cfg['down_growth'],) * num_stages if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages + if 'avg_down' in cfg and not isinstance(cfg['avg_down'], (list, tuple)): + cfg['avg_down'] = (cfg['avg_down'],) * num_stages cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] stage_strides = [] @@ -352,9 +595,20 @@ class CspNet(nn.Module): """ def __init__( - self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., - act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0., - zero_init_last=True, stage_fn=CrossStage, block_fn=ResBottleneck): + self, + cfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + act_layer=nn.LeakyReLU, + norm_layer=nn.BatchNorm2d, + aa_layer=None, + drop_rate=0., + drop_path_rate=0., + zero_init_last=True, + stage_fn=CrossStage, + block_fn=ResBottleneck): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -427,23 +681,22 @@ class CspNet(nn.Module): def _init_weights(module, name, zero_init_last=False): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(module, nn.BatchNorm2d): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) + if module.bias is not None: + nn.init.zeros_(module.bias) elif isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.01) - nn.init.zeros_(module.bias) + if module.bias is not None: + nn.init.zeros_(module.bias) elif zero_init_last and hasattr(module, 
'zero_init_last'): module.zero_init_last() def _create_cspnet(variant, pretrained=False, **kwargs): - cfg_variant = variant.split('_')[0] # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4)) return build_model_with_cfg( CspNet, variant, pretrained, - model_cfg=model_cfgs[cfg_variant], + model_cfg=model_cfgs[variant], feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) @@ -469,22 +722,55 @@ def cspresnext50(pretrained=False, **kwargs): @register_model -def cspresnext50_iabn(pretrained=False, **kwargs): - norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu') - return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs) +def cspdarknet53(pretrained=False, **kwargs): + return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) @register_model -def cspdarknet53(pretrained=False, **kwargs): - return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) +def darknet17(pretrained=False, **kwargs): + return _create_cspnet('darknet17', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) @register_model -def cspdarknet53_iabn(pretrained=False, **kwargs): - norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu') - return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs) +def darknet21(pretrained=False, **kwargs): + return _create_cspnet('darknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + + +@register_model +def sedarknet21(pretrained=False, **kwargs): + return _create_cspnet('sedarknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) @register_model def darknet53(pretrained=False, **kwargs): return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + + +@register_model +def darknetaa53(pretrained=False, **kwargs): + return _create_cspnet( + 'darknetaa53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + + +@register_model +def cs2darknet_m(pretrained=False, **kwargs): + return _create_cspnet( + 'cs2darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + + +@register_model +def cs2darknet_l(pretrained=False, **kwargs): + return _create_cspnet( + 'cs2darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + + +@register_model +def cs2darknet_f_m(pretrained=False, **kwargs): + return _create_cspnet( + 'cs2darknet_f_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + + +@register_model +def cs2darknet_f_l(pretrained=False, **kwargs): + return _create_cspnet( + 'cs2darknet_f_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) \ No newline at end of file diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py index af010573..9e7c64b8 100644 --- a/timm/models/layers/conv_bn_act.py +++ b/timm/models/layers/conv_bn_act.py @@ -2,6 +2,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import functools from torch import nn as nn from .create_conv2d import create_conv2d @@ -40,12 +41,26 @@ class ConvNormAct(nn.Module): ConvBnAct = ConvNormAct +def create_aa(aa_layer, channels, stride=2, 
enable=True): + if not aa_layer or not enable: + return nn.Identity() + if isinstance(aa_layer, functools.partial): + if issubclass(aa_layer.func, nn.AvgPool2d): + return aa_layer() + else: + return aa_layer(channels) + elif issubclass(aa_layer, nn.AvgPool2d): + return aa_layer(stride) + else: + return aa_layer(channels=channels, stride=stride) + + class ConvNormActAa(nn.Module): def __init__( self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): super(ConvNormActAa, self).__init__() - use_aa = aa_layer is not None + use_aa = aa_layer is not None and stride == 2 self.conv = create_conv2d( in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, @@ -56,7 +71,7 @@ class ConvNormActAa(nn.Module): # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) - self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity() + self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) @property def in_channels(self): From 7a9c6811c91123f84af963e5302a9d18c7c33716 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 15:15:39 -0700 Subject: [PATCH 04/26] Add eps arg to LayerNorm2d, add 'tf' (tensorflow) variant of trunc_normal_ that applies scale/shift after sampling (instead of needing to move a/b) --- timm/models/layers/__init__.py | 2 +- timm/models/layers/norm.py | 4 ++-- timm/models/layers/weight_init.py | 36 ++++++++++++++++++++++++++++++- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index b1a64db3..b1f452ff 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -39,4 +39,4 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .trace_utils import _assert, _float_to_int -from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ +from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 85297420..345f67bc 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -16,8 +16,8 @@ class GroupNorm(nn.GroupNorm): class LayerNorm2d(nn.LayerNorm): """ LayerNorm for channels of '2D' spatial BCHW tensors """ - def __init__(self, num_channels): - super().__init__(num_channels) + def __init__(self, num_channels, eps=1e-6): + super().__init__(num_channels, eps=eps) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py index 305a2fd0..4a160931 100644 --- a/timm/models/layers/weight_init.py +++ b/timm/models/layers/weight_init.py @@ -49,6 +49,11 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \leq \text{mean} \leq b`. 
+
+    NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
+    applied while sampling the normal with mean/std applied, therefore a, b args
+    should be adjusted to match the range of mean, std args.
+
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
         std: the standard deviation of the normal distribution
         a: the minimum cutoff value
         b: the maximum cutoff value
     Examples:
         >>> w = torch.empty(3, 5)
         >>> nn.init.trunc_normal_(w)
@@ -62,6 +67,35 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
 
 
+def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+    and the result is subsequently scaled and shifted by the mean and std args.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    _no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
+    with torch.no_grad():
+        tensor.mul_(std).add_(mean)
+    return tensor
+
+
 def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
     if mode == 'fan_in':
@@ -75,7 +109,7 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
 
     if distribution == "truncated_normal":
         # constant is stddev of standard normal truncated to (-2, 2)
-        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+        trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
     elif distribution == "normal":
         tensor.normal_(std=math.sqrt(variance))
     elif distribution == "uniform":

From 6064d16a2dfe89b1d3706df338cecfdcee395d1f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 1 Jul 2022 15:16:41 -0700
Subject: [PATCH 05/26] Add initial EdgeNeXt import. Significant cleanup / reorg (like ConvNeXt). 
Fix #1320 * edgenext refactored for torchscript compat, stage base organization * slight refactor of ConvNeXt to match some EdgeNeXt additions * remove use of funky LayerNorm layer in ConvNeXt and just use nn.LayerNorm and LayerNorm2d (permute) --- timm/models/__init__.py | 1 + timm/models/convnext.py | 190 ++++++++------ timm/models/edgenext.py | 545 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 665 insertions(+), 71 deletions(-) create mode 100644 timm/models/edgenext.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 4f81683a..195e451b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -12,6 +12,7 @@ from .deit import * from .densenet import * from .dla import * from .dpn import * +from .edgenext import * from .efficientnet import * from .ghostnet import * from .gluon_resnet import * diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 1aacef2b..662695c7 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .fx_features import register_notrace_module from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp +from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d from .registry import register_model @@ -44,6 +44,7 @@ default_cfgs = dict( convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_nano_hnf=_cfg(url=''), + convnext_nano_ols=_cfg(url=''), convnext_tiny_hnf=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', crop_pct=0.95), @@ -88,35 +89,6 @@ default_cfgs = dict( ) -def _is_contiguous(tensor: torch.Tensor) -> bool: - # jit is oh so lovely :/ - # if torch.jit.is_tracing(): - # return True - if torch.jit.is_scripting(): - return tensor.is_contiguous() - else: - return tensor.is_contiguous(memory_format=torch.contiguous_format) - - -@register_notrace_module -class LayerNorm2d(nn.LayerNorm): - r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). - """ - - def __init__(self, normalized_shape, eps=1e-6): - super().__init__(normalized_shape, eps=eps) - - def forward(self, x) -> torch.Tensor: - if _is_contiguous(x): - return F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) - else: - s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) - x = (x - u) * torch.rsqrt(s + self.eps) - x = x * self.weight[:, None, None] + self.bias[:, None, None] - return x - - class ConvNeXtBlock(nn.Module): """ ConvNeXt Block There are two equivalent implementations: @@ -133,21 +105,39 @@ class ConvNeXtBlock(nn.Module): ls_init_value (float): Init value for Layer Scale. Default: 1e-6. 
""" - def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None): + def __init__( + self, + dim, + dim_out=None, + stride=1, + mlp_ratio=4, + conv_mlp=False, + conv_bias=True, + ls_init_value=1e-6, + norm_layer=None, + act_layer=nn.GELU, + drop_path=0., + ): super().__init__() + dim_out = dim_out or dim if not norm_layer: norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp - self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv - self.norm = norm_layer(dim) - self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU) - self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.shortcut_after_dw = stride > 1 + + self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias) + self.norm = norm_layer(dim_out) + self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): shortcut = x x = self.conv_dw(x) + if self.shortcut_after_dw: + shortcut = x + if self.use_conv_mlp: x = self.norm(x) x = self.mlp(x) @@ -158,32 +148,55 @@ class ConvNeXtBlock(nn.Module): x = x.permute(0, 3, 1, 2) if self.gamma is not None: x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + x = self.drop_path(x) + shortcut + #print('b', x.shape) return x class ConvNeXtStage(nn.Module): def __init__( - self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, - norm_layer=None, cl_norm_layer=None, cross_stage=False): + self, + in_chs, + out_chs, + stride=2, + depth=2, + drop_path_rates=None, + ls_init_value=1.0, + downsample_block=False, + conv_mlp=False, + conv_bias=True, + norm_layer=None, + norm_layer_cl=None + ): super().__init__() self.grad_checkpointing = False - if in_chs != out_chs or stride > 1: + if downsample_block or (in_chs == out_chs and stride == 1): + self.downsample = nn.Identity() + else: self.downsample = nn.Sequential( norm_layer(in_chs), - nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride), + nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias), ) - else: - self.downsample = nn.Identity() - - dp_rates = dp_rates or [0.] * depth - self.blocks = nn.Sequential(*[ConvNeXtBlock( - dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, - norm_layer=norm_layer if conv_mlp else cl_norm_layer) - for j in range(depth)] - ) + in_chs = out_chs + + drop_path_rates = drop_path_rates or [0.] 
* depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append(ConvNeXtBlock( + dim=in_chs, + dim_out=out_chs, + stride=stride if downsample_block and i == 0 else 1, + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + norm_layer=norm_layer if conv_mlp else norm_layer_cl + )) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) def forward(self, x): x = self.downsample(x) @@ -210,41 +223,57 @@ class ConvNeXt(nn.Module): """ def __init__( - self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, - depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch', - head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + output_stride=32, + depths=(3, 3, 9, 3), + dims=(96, 192, 384, 768), + ls_init_value=1e-6, + stem_type='patch', + stem_kernel_size=4, + stem_stride=4, + head_init_scale=1., + head_norm_first=False, + downsample_block=False, + conv_mlp=False, + conv_bias=True, + norm_layer=None, + drop_rate=0., + drop_path_rate=0., ): super().__init__() assert output_stride == 32 if norm_layer is None: norm_layer = partial(LayerNorm2d, eps=1e-6) - cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' - cl_norm_layer = norm_layer + norm_layer_cl = norm_layer self.num_classes = num_classes self.drop_rate = drop_rate self.feature_info = [] - # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + assert stem_type in ('patch', 'overlap') if stem_type == 'patch': + assert stem_kernel_size == stem_stride + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), + nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias), norm_layer(dims[0]) ) - curr_stride = patch_size - prev_chs = dims[0] else: self.stem = nn.Sequential( - nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1), - norm_layer(32), - nn.GELU(), - nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.Conv2d( + in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, + padding=stem_kernel_size // 2, bias=conv_bias), + norm_layer(dims[0]), ) - curr_stride = 2 - prev_chs = 64 + prev_chs = dims[0] + curr_stride = stem_stride self.stages = nn.Sequential() dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] @@ -256,16 +285,24 @@ class ConvNeXt(nn.Module): curr_stride *= stride out_chs = dims[i] stages.append(ConvNeXtStage( - prev_chs, out_chs, stride=stride, - depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, - norm_layer=norm_layer, cl_norm_layer=cl_norm_layer) - ) + prev_chs, + out_chs, + stride=stride, + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + downsample_block=downsample_block, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl + )) prev_chs = out_chs # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 self.feature_info += 
[dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.num_features = prev_chs + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() @@ -327,10 +364,11 @@ class ConvNeXt(nn.Module): def _init_weights(module, name=None, head_init_scale=1.0): if isinstance(module, nn.Conv2d): trunc_normal_(module.weight, std=.02) - nn.init.constant_(module.bias, 0) + if module.bias is not None: + nn.init.zeros_(module.bias) elif isinstance(module, nn.Linear): trunc_normal_(module.weight, std=.02) - nn.init.constant_(module.bias, 0) + nn.init.zeros_(module.bias) if name and 'head.' in name: module.weight.data.mul_(head_init_scale) module.bias.data.mul_(head_init_scale) @@ -371,11 +409,21 @@ def _create_convnext(variant, pretrained=False, **kwargs): @register_model def convnext_nano_hnf(pretrained=False, **kwargs): - model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) + model_args = dict( + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args) return model +@register_model +def convnext_nano_ols(pretrained=False, **kwargs): + model_args = dict( + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), downsample_block=True, + conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs) + model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args) + return model + + @register_model def convnext_tiny_hnf(pretrained=False, **kwargs): model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py new file mode 100644 index 00000000..0f8b0464 --- /dev/null +++ b/timm/models/edgenext.py @@ -0,0 +1,545 @@ +""" EdgeNeXt + +Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications` + - https://arxiv.org/abs/2206.10589 + +Original code and weights from https://github.com/mmaaz60/EdgeNeXt + +Modifications and additions for timm by / Copyright 2022, Ross Wightman +""" +import math +import torch +from collections import OrderedDict +from functools import partial +from typing import Tuple + +from torch import nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import trunc_normal_tf_ +from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from .helpers import named_apply, build_model_with_cfg, checkpoint_seq +from .registry import register_model + + +__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + edgenext_xx_small=_cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth"), + edgenext_x_small=_cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth"), + # 
edgenext_small=_cfg( + # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"), + edgenext_small=_cfg( # USI weights + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", + crop_pct=0.95 + ), + + edgenext_small_rw=_cfg(), +) + + +class PositionalEncodingFourier(nn.Module): + def __init__(self, hidden_dim=32, dim=768, temperature=10000): + super().__init__() + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + + def forward(self, shape: Tuple[int, int, int]): + inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool) + y_embed = inv_mask.cumsum(1, dtype=torch.float32) + x_embed = inv_mask.cumsum(2, dtype=torch.float32) + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), + pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), + pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + + return pos + + +class ConvBlock(nn.Module): + def __init__( + self, + dim, + dim_out=None, + kernel_size=7, + stride=1, + conv_bias=True, + expand_ratio=4, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, drop_path=0., + ): + super().__init__() + dim_out = dim_out or dim + self.shortcut_after_dw = stride > 1 or dim != dim_out + + self.conv_dw = create_conv2d( + dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias) + self.norm = norm_layer(dim_out) + self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + if self.shortcut_after_dw: + shortcut = x + + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class CrossCovarianceAttn(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0. 
+ ): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) + + # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature'} + + +class SplitTransposeBlock(nn.Module): + def __init__( + self, + dim, + num_scales=1, + num_heads=8, + expand_ratio=4, + use_pos_emb=True, + conv_bias=True, + qkv_bias=True, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + drop_path=0., + attn_drop=0., + proj_drop=0. + ): + super().__init__() + width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales))) + self.width = width + self.num_scales = max(1, num_scales - 1) + + convs = [] + for i in range(self.num_scales): + convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias)) + self.convs = nn.ModuleList(convs) + + self.pos_embd = None + if use_pos_emb: + self.pos_embd = PositionalEncodingFourier(dim=dim) + self.norm_xca = norm_layer(dim) + self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.xca = CrossCovarianceAttn( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) + + self.norm = norm_layer(dim, eps=1e-6) + self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + shortcut = x + + # scales code re-written for torchscript as per my res2net fixes -rw + spx = torch.split(x, self.width, 1) + spo = [] + sp = spx[0] + for i, conv in enumerate(self.convs): + if i > 0: + sp = sp + spx[i] + sp = conv(sp) + spo.append(sp) + spo.append(spx[-1]) + x = torch.cat(spo, 1) + + # XCA + B, C, H, W = x.shape + x = x.reshape(B, C, H * W).permute(0, 2, 1) + if self.pos_embd is not None: + pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1) + x = x + pos_encoding + x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x))) + x = x.reshape(B, H, W, C) + + # Inverted Bottleneck + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class EdgeNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=2, + depth=2, + num_global_blocks=1, + num_heads=4, + scales=2, + kernel_size=7, + expand_ratio=4, + use_pos_emb=False, + downsample_block=False, + conv_bias=True, + ls_init_value=1.0, + drop_path_rates=None, + norm_layer=LayerNorm2d, + norm_layer_cl=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU + ): + super().__init__() + self.grad_checkpointing = False + + if downsample_block or stride == 1: + self.downsample = nn.Identity() + else: + self.downsample = nn.Sequential( + norm_layer(in_chs), + nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias) + ) + in_chs = out_chs + + stage_blocks = [] + for i in range(depth): + if i < depth - num_global_blocks: + stage_blocks.append( + ConvBlock( + dim=in_chs, + dim_out=out_chs, + stride=stride if downsample_block and i == 0 else 1, + conv_bias=conv_bias, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + else: + stage_blocks.append( + SplitTransposeBlock( + dim=in_chs, + num_scales=scales, + num_heads=num_heads, + expand_ratio=expand_ratio, + use_pos_emb=use_pos_emb, + conv_bias=conv_bias, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) + + def forward(self, x): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class EdgeNeXt(nn.Module): + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + dims=(24, 48, 88, 168), + depths=(3, 3, 9, 3), + global_block_counts=(0, 1, 1, 1), + kernel_sizes=(3, 5, 7, 9), + heads=(8, 8, 8, 8), + d2_scales=(2, 2, 3, 4), + use_pos_emb=(False, True, False, False), + ls_init_value=1e-6, + head_init_scale=1., + expand_ratio=4, + downsample_block=False, + conv_bias=True, + stem_type='patch', + head_norm_first=False, + act_layer=nn.GELU, + drop_path_rate=0., + drop_rate=0., + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.drop_rate = drop_rate + norm_layer = partial(LayerNorm2d, eps=1e-6) + norm_layer_cl = partial(nn.LayerNorm, eps=1e-6) + + assert stem_type in ('patch', 'overlap') + if stem_type == 'patch': + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias), + norm_layer(dims[0]), + ) + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, 
dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias), + norm_layer(dims[0]), + ) + + stages = [] + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + in_chs = dims[0] + for i in range(4): + stages.append(EdgeNeXtStage( + in_chs=in_chs, + out_chs=dims[i], + stride=2 if i > 0 else 1, + depth=depths[i], + num_global_blocks=global_block_counts[i], + num_heads=heads[i], + drop_path_rates=dp_rates[i], + scales=d2_scales[i], + expand_ratio=expand_ratio, + kernel_size=kernel_sizes[i], + use_pos_emb=use_pos_emb[i], + ls_init_value=ls_init_value, + downsample_block=downsample_block, + conv_bias=conv_bias, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + act_layer=act_layer, + )) + in_chs = dims[i] + self.stages = nn.Sequential(*stages) + + self.num_features = dims[-1] + self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) + + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes=0, global_pool=None): + if global_pool is not None: + self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( + x = self.head.global_pool(x) + x = self.head.norm(x) + x = self.head.flatten(x) + x = self.head.drop(x) + return x if pre_logits else self.head.fc(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _init_weights(module, name=None, head_init_scale=1.0): + if isinstance(module, nn.Conv2d): + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_tf_(module.weight, std=.02) + nn.init.zeros_(module.bias) + if name and 'head.' in name: + module.weight.data.mul_(head_init_scale) + module.bias.data.mul_(head_init_scale) + + +def checkpoint_filter_fn(state_dict, model): + """ Remap FB checkpoints -> timm """ + if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: + return state_dict # non-FB checkpoint + + # models were released as train checkpoints... 
:/ + if 'model_ema' in state_dict: + state_dict = state_dict['model_ema'] + elif 'model' in state_dict: + state_dict = state_dict['model'] + elif 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + out_dict = {} + import re + for k, v in state_dict.items(): + k = k.replace('downsample_layers.0.', 'stem.') + k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) + k = k.replace('dwconv', 'conv_dw') + k = k.replace('pwconv', 'mlp.fc') + k = k.replace('head.', 'head.fc.') + if k.startswith('norm.'): + k = k.replace('norm', 'head.norm') + if v.ndim == 2 and 'head' not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + return out_dict + + +def _create_edgenext(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + EdgeNeXt, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs) + return model + + +@register_model +def edgenext_xx_small(pretrained=False, **kwargs): + # 1.33M & 260.58M @ 256 resolution + # 71.23% Top-1 accuracy + # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS + # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS + model_kwargs = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4), **kwargs) + return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_x_small(pretrained=False, **kwargs): + # 2.34M & 538.0M @ 256 resolution + # 75.00% Top-1 accuracy + # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=31.61 versus 28.49 for MobileViT_XS + # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS + model_kwargs = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4), **kwargs) + return _create_edgenext('edgenext_x_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_small(pretrained=False, **kwargs): + # 5.59M & 1260.59M @ 256 resolution + # 79.43% Top-1 accuracy + # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=20.47 versus 18.86 for MobileViT_S + # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S + model_kwargs = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304), **kwargs) + return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_small_rw(pretrained=False, **kwargs): + # 5.59M & 1260.59M @ 256 resolution + # 79.43% Top-1 accuracy + # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=20.47 versus 18.86 for MobileViT_S + # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S + model_kwargs = dict( + depths=(3, 3, 9, 3), dims=(48, 96, 192, 384), + downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs) + return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **model_kwargs) + From 70d6d2c4847982a8f20c4233a28ba84ea9485868 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:17:05 -0700 Subject: [PATCH 06/26] support 
test_crop_size in data config resolve --- timm/data/config.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/timm/data/config.py b/timm/data/config.py index 38f5689a..78176e4b 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -64,11 +64,15 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v new_config['std'] = default_cfg['std'] # resolve default crop percentage - new_config['crop_pct'] = DEFAULT_CROP_PCT + crop_pct = DEFAULT_CROP_PCT if 'crop_pct' in args and args['crop_pct'] is not None: - new_config['crop_pct'] = args['crop_pct'] - elif 'crop_pct' in default_cfg: - new_config['crop_pct'] = default_cfg['crop_pct'] + crop_pct = args['crop_pct'] + else: + if use_test_size and 'test_crop_pct' in default_cfg: + crop_pct = default_cfg['test_crop_pct'] + elif 'crop_pct' in default_cfg: + crop_pct = default_cfg['crop_pct'] + new_config['crop_pct'] = crop_pct if verbose: _logger.info('Data processing configuration for current model + dataset:') From 188c194b0f7bad1aa6c5db46e04c3ef63d2b10e6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:17:28 -0700 Subject: [PATCH 07/26] Left some experiment stem code in convnext by mistake --- timm/models/convnext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 662695c7..138e5030 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -434,7 +434,7 @@ def convnext_tiny_hnf(pretrained=False, **kwargs): @register_model def convnext_tiny_hnfd(pretrained=False, **kwargs): model_args = dict( - depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, stem_type='dual', **kwargs) + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) return model From c170ba317318599e759d4f004e6ee6aebf1fc258 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:18:06 -0700 Subject: [PATCH 08/26] Add weights for resnet10t, resnet14t, and resnetaa50 models. 
Fix #1314
---
 timm/models/resnet.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 476ffe91..28f3cdba 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -35,6 +35,16 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = {
     # ResNet and Wide ResNet
+    'resnet10t': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet10t_176_c3-f3215ab1.pth',
+        input_size=(3, 176, 176), pool_size=(6, 6),
+        test_crop_pct=0.95, test_input_size=(3, 224, 224),
+        first_conv='conv1.0'),
+    'resnet14t': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet14t_176_c3-c4ed2c37.pth',
+        input_size=(3, 176, 176), pool_size=(6, 6),
+        test_crop_pct=0.95, test_input_size=(3, 224, 224),
+        first_conv='conv1.0'),
     'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
     'resnet18d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth',
@@ -262,6 +272,10 @@ default_cfgs = {
     'resnetblur101d': _cfg(
         url='',
         interpolation='bicubic', first_conv='conv1.0'),
+    'resnetaa50': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0,
+        interpolation='bicubic', first_conv='conv1.0'),
     'resnetaa50d': _cfg(
         url='',
         interpolation='bicubic', first_conv='conv1.0'),
@@ -1454,6 +1468,14 @@ def resnetblur101d(pretrained=False, **kwargs):
     return _create_resnet('resnetblur101d', pretrained, **model_args)
 
 
+@register_model
+def resnetaa50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model with avgpool anti-aliasing
+    """
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, **kwargs)
+    return _create_resnet('resnetaa50', pretrained, **model_args)
+
+
 @register_model
 def resnetaa50d(pretrained=False, **kwargs):
     """Constructs a ResNet-50-D model with avgpool anti-aliasing

From 377e9bfa217b60601fb6473022970f115a5455ca Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 2 Jul 2022 15:18:52 -0700
Subject: [PATCH 09/26] Add TPU trained darknet53 weights. Add missing pretrain_cfg for some csp/darknet models. 
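
The darknet53 pretrained cfg added below carries test_input_size=(3, 288, 288) and
test_crop_pct=1.0, which the resolve_data_config change from PATCH 06 picks up at
eval time. A rough usage sketch, assuming a checkout with this series applied:

import timm
from timm.data import resolve_data_config

model = timm.create_model('darknet53', pretrained=False)  # pretrained=True fetches the new weights
train_cfg = resolve_data_config({}, model=model)                     # train-size config
test_cfg = resolve_data_config({}, model=model, use_test_size=True)  # test-size config
# expected from the cfg below: test_cfg['input_size'] == (3, 288, 288), test_cfg['crop_pct'] == 1.0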
--- timm/models/cspnet.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 095e4701..77473052 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -45,14 +45,18 @@ default_cfgs = { 'cspresnet50w': _cfg(url=''), 'cspresnext50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth', - input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.875 # FIXME I trained this at 224x224, not 256 like ref impl ), 'cspdarknet53': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'), 'darknet17': _cfg(url=''), 'darknet21': _cfg(url=''), - 'darknet53': _cfg(url=''), + 'sedarknet21': _cfg(url=''), + 'darknet53': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic' + ), + 'darknetaa53': _cfg(url=''), 'cs2darknet_m': _cfg( url=''), From dd9b8f57c4862d4edd87dc1e0a3b34ff005a27f4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:20:45 -0700 Subject: [PATCH 10/26] Add feature_info to edgenext for features_only support, hopefully fix some fx / test errors --- timm/models/edgenext.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 0f8b0464..97971ba6 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -17,8 +17,8 @@ from torch import nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers import trunc_normal_tf_ -from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from .fx_features import register_notrace_module +from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .registry import register_model @@ -53,6 +53,7 @@ default_cfgs = dict( ) +@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): def __init__(self, hidden_dim=32, dim=768, temperature=10000): super().__init__() @@ -349,6 +350,7 @@ class EdgeNeXt(nn.Module): self.drop_rate = drop_rate norm_layer = partial(LayerNorm2d, eps=1e-6) norm_layer_cl = partial(nn.LayerNorm, eps=1e-6) + self.feature_info = [] assert stem_type in ('patch', 'overlap') if stem_type == 'patch': @@ -362,14 +364,18 @@ class EdgeNeXt(nn.Module): norm_layer(dims[0]), ) + curr_stride = 4 stages = [] dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] in_chs = dims[0] for i in range(4): + stride = 2 if curr_stride == 2 or i > 0 else 1 + # FIXME support dilation / output_stride + curr_stride *= stride stages.append(EdgeNeXtStage( in_chs=in_chs, out_chs=dims[i], - stride=2 if i > 0 else 1, + stride=stride, depth=depths[i], num_global_blocks=global_block_counts[i], num_heads=heads[i], @@ -385,7 +391,10 @@ class EdgeNeXt(nn.Module): norm_layer_cl=norm_layer_cl, act_layer=act_layer, )) + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 in_chs = dims[i] + self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) self.num_features = 
dims[-1] From d76530582164740e65b6992148d2a755f16cde6b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Jul 2022 15:56:17 -0700 Subject: [PATCH 11/26] Remove first_conv for resnetaa50 def --- timm/models/resnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 28f3cdba..e5a6b791 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -274,8 +274,7 @@ default_cfgs = { interpolation='bicubic', first_conv='conv1.0'), 'resnetaa50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0, - interpolation='bicubic', first_conv='conv1.0'), + test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic'), 'resnetaa50d': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), From d0c5bd57223c3f1da58219f497fe48d478f873da Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Jul 2022 08:32:41 -0700 Subject: [PATCH 12/26] Rename cs2->cs3 for darknets. Fix features_only for cs3 darknets. --- timm/models/cspnet.py | 62 ++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 77473052..4591f101 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -58,13 +58,13 @@ default_cfgs = { ), 'darknetaa53': _cfg(url=''), - 'cs2darknet_m': _cfg( + 'cs3darknet_m': _cfg( url=''), - 'cs2darknet_l': _cfg( + 'cs3darknet_l': _cfg( url=''), - 'cs2darknet_f_m': _cfg( + 'cs3darknet_focus_m': _cfg( url=''), - 'cs2darknet_f_l': _cfg( + 'cs3darknet_focus_l': _cfg( url=''), } @@ -185,7 +185,7 @@ model_cfgs = dict( ), ), - cs2darknet_m=dict( + cs3darknet_m=dict( stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''), stage=dict( out_chs=(96, 192, 384, 768), @@ -196,12 +196,11 @@ model_cfgs = dict( avg_down=False, ), ), - - cs2darknet_f_m=dict( - stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), + cs3darknet_l=dict( + stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), stage=dict( - out_chs=(96, 192, 384, 768), - depth=(2, 4, 6, 2), + out_chs=(128, 256, 512, 1024), + depth=(3, 6, 9, 3), stride=(2,) * 4, bottle_ratio=(1.,) * 4, block_ratio=(0.5,) * 4, @@ -209,19 +208,18 @@ model_cfgs = dict( ), ), - cs2darknet_l=dict( - stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), + cs3darknet_focus_m=dict( + stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), stage=dict( - out_chs=(128, 256, 512, 1024), - depth=(3, 6, 9, 3), + out_chs=(96, 192, 384, 768), + depth=(2, 4, 6, 2), stride=(2,) * 4, bottle_ratio=(1.,) * 4, block_ratio=(0.5,) * 4, avg_down=False, ), ), - - cs2darknet_f_l=dict( + cs3darknet_focus_l=dict( stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''), stage=dict( out_chs=(128, 256, 512, 1024), @@ -438,9 +436,9 @@ class CrossStage(nn.Module): return out -class CrossStage2(nn.Module): - """Cross Stage v2. - Similar to CrossStage, but with one transition conv for the concat output. +class CrossStage3(nn.Module): + """Cross Stage 3. + Similar to CrossStage, but with only one transition conv for the output. 
""" def __init__( self, @@ -461,7 +459,7 @@ class CrossStage2(nn.Module): block_fn=ResBottleneck, **block_kwargs ): - super(CrossStage2, self).__init__() + super(CrossStage3, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) @@ -696,8 +694,12 @@ def _init_weights(module, name, zero_init_last=False): def _create_cspnet(variant, pretrained=False, **kwargs): - # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4)) + if variant.startswith('darknet') or variant.startswith('cspdarknet'): + # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] + default_out_indices = (0, 1, 2, 3, 4, 5) + else: + default_out_indices = (0, 1, 2, 3, 4) + out_indices = kwargs.pop('out_indices', default_out_indices) return build_model_with_cfg( CspNet, variant, pretrained, model_cfg=model_cfgs[variant], @@ -757,24 +759,24 @@ def darknetaa53(pretrained=False, **kwargs): @register_model -def cs2darknet_m(pretrained=False, **kwargs): +def cs3darknet_m(pretrained=False, **kwargs): return _create_cspnet( - 'cs2darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + 'cs3darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) @register_model -def cs2darknet_l(pretrained=False, **kwargs): +def cs3darknet_l(pretrained=False, **kwargs): return _create_cspnet( - 'cs2darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + 'cs3darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) @register_model -def cs2darknet_f_m(pretrained=False, **kwargs): +def cs3darknet_focus_m(pretrained=False, **kwargs): return _create_cspnet( - 'cs2darknet_f_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) + 'cs3darknet_focus_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) @register_model -def cs2darknet_f_l(pretrained=False, **kwargs): +def cs3darknet_focus_l(pretrained=False, **kwargs): return _create_cspnet( - 'cs2darknet_f_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs) \ No newline at end of file + 'cs3darknet_focus_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) \ No newline at end of file From 7d4b3807d5c40b0f8d7e66d27a7672684e482996 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 4 Jul 2022 22:25:22 -0700 Subject: [PATCH 13/26] Support DeiT-3 (Revenge of the ViT) checkpoints. Add non-overlapping (w/ class token) pos-embed support to vit. 
--- timm/models/deit.py | 204 +++++++++++++++++++++++++++++- timm/models/vision_transformer.py | 64 +++++++--- 2 files changed, 247 insertions(+), 21 deletions(-) diff --git a/timm/models/deit.py b/timm/models/deit.py index e6b4b025..a2f43b91 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -1,7 +1,10 @@ """ DeiT - Data-efficient Image Transformers DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below -paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118 Modifications copyright 2021, Ross Wightman """ @@ -53,6 +56,46 @@ default_cfgs = { url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), + + 'deit3_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'), + 'deit3_small_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'), + 'deit3_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_large_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'), + 'deit3_large_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_huge_patch14_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'), + + 'deit3_small_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth', + crop_pct=1.0), + 'deit3_small_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_base_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth', + crop_pct=1.0), + 'deit3_base_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_large_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth', + crop_pct=1.0), + 'deit3_large_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_huge_patch14_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth', + crop_pct=1.0), } @@ -68,9 +111,10 @@ class VisionTransformerDistilled(VisionTransformer): super().__init__(*args, **kwargs, weight_init='skip') assert self.global_pool in ('token',) - self.num_tokens = 2 + self.num_prefix_tokens = 2 self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim)) self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() self.distilled_training = False # must set this True to train w/ distillation token @@ -220,3 +264,157 @@ def 
deit_base_distilled_patch16_384(pretrained=False, **kwargs): model = _create_deit( 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) return model + + +@register_model +def deit3_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_huge_patch14_224(pretrained=False, **kwargs): + """ DeiT-3 huge model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. 
+ """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 8551feae..022052d0 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -325,8 +325,8 @@ class VisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', - embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): + class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): """ Args: img_size (int, tuple): input image size @@ -360,15 +360,17 @@ class VisionTransformer(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 if class_token else 0 + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None - self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -428,11 +430,24 @@ class VisionTransformer(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def _pos_embed(self, x): + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + self.pos_embed + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed + return self.pos_drop(x) + def forward_features(self, x): x = self.patch_embed(x) - if self.cls_token is not None: - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = self.pos_drop(x + self.pos_embed) + x = self._pos_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: @@ -442,7 +457,7 @@ class VisionTransformer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: - x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = x[:, self.num_prefix_tokens:].mean(dim=1) if 
self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) return x if pre_logits else self.head(x) @@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) if pos_embed_w.shape != model.pos_embed.shape: pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights - pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + pos_embed_w, + model.pos_embed, + getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) @@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) -def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): +def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] - if num_tokens: - posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] - ntok_new -= num_tokens + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] + ntok_new -= num_prefix_tokens else: - posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 @@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) - posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) return posemb def checkpoint_filter_fn(state_dict, model): """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re out_dict = {} if 'model' in state_dict: # For deit models state_dict = state_dict['model'] + for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: # For old models that I trained prior to conv based patchification O, I, H, W = model.patch_embed.proj.weight.shape v = v.reshape(O, -1, H, W) - elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( - v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + v, + model.pos_embed, + getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + elif 'gamma_' in k: + # remap layer-scale gamma into sub-module (deit3 models) + k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) elif 'pre_logits' in k: # NOTE representation layer removed as not used in latest 21k/1k pretrained 
weights continue From bfc0dccb0ed1026f596797818ab865ea53ef3d2c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 14:23:20 -0700 Subject: [PATCH 14/26] Improve image extension handling, add methods to modify / get defaults. Fix #1335 fix #1274. --- timm/data/__init__.py | 5 ++- timm/data/parsers/__init__.py | 1 + timm/data/parsers/constants.py | 1 - timm/data/parsers/img_extensions.py | 50 ++++++++++++++++++++++++ timm/data/parsers/parser_factory.py | 1 - timm/data/parsers/parser_image_folder.py | 29 ++++++++++++-- timm/data/parsers/parser_image_in_tar.py | 29 ++++++++------ timm/data/parsers/parser_image_tar.py | 10 +++-- 8 files changed, 103 insertions(+), 23 deletions(-) delete mode 100644 timm/data/parsers/constants.py create mode 100644 timm/data/parsers/img_extensions.py diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 7d3cb2b4..0eb10a66 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset_factory import create_dataset from .loader import create_loader from .mixup import Mixup, FastCollateMixup -from .parsers import create_parser +from .parsers import create_parser,\ + get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions from .real_labels import RealLabelsImagenet from .transforms import * -from .transforms_factory import create_transform \ No newline at end of file +from .transforms_factory import create_transform diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py index eeb44e37..4e820d5e 100644 --- a/timm/data/parsers/__init__.py +++ b/timm/data/parsers/__init__.py @@ -1 +1,2 @@ from .parser_factory import create_parser +from .img_extensions import * diff --git a/timm/data/parsers/constants.py b/timm/data/parsers/constants.py deleted file mode 100644 index e7ba484e..00000000 --- a/timm/data/parsers/constants.py +++ /dev/null @@ -1 +0,0 @@ -IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') diff --git a/timm/data/parsers/img_extensions.py b/timm/data/parsers/img_extensions.py new file mode 100644 index 00000000..45c85aab --- /dev/null +++ b/timm/data/parsers/img_extensions.py @@ -0,0 +1,50 @@ +from copy import deepcopy + +__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] + + +IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use +_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync + + +def _set_extensions(extensions): + global IMG_EXTENSIONS + global _IMG_EXTENSIONS_SET + dedupe = set() # NOTE de-duping tuple while keeping original order + IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) + _IMG_EXTENSIONS_SET = set(extensions) + + +def _valid_extension(x: str): + return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') + + +def is_img_extension(ext): + return ext in _IMG_EXTENSIONS_SET + + +def get_img_extensions(as_set=False): + return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) + + +def set_img_extensions(extensions): + assert len(extensions) + for x in extensions: + assert _valid_extension(x) + _set_extensions(extensions) + + +def add_img_extensions(ext): + if not isinstance(ext, (list, tuple, set)): + ext = (ext,) + for x in ext: + assert _valid_extension(x) + extensions = IMG_EXTENSIONS + tuple(ext) + _set_extensions(extensions) + + +def del_img_extensions(ext): + if not isinstance(ext, 
(list, tuple, set)): + ext = (ext,) + extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) + _set_extensions(extensions) diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py index 892090ad..0665c02a 100644 --- a/timm/data/parsers/parser_factory.py +++ b/timm/data/parsers/parser_factory.py @@ -1,7 +1,6 @@ import os from .parser_image_folder import ParserImageFolder -from .parser_image_tar import ParserImageTar from .parser_image_in_tar import ParserImageInTar diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/parsers/parser_image_folder.py index ed349009..3d22a17b 100644 --- a/timm/data/parsers/parser_image_folder.py +++ b/timm/data/parsers/parser_image_folder.py @@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default. Hacked together by / Copyright 2020 Ross Wightman """ import os +from typing import Dict, List, Optional, Set, Tuple, Union from timm.utils.misc import natural_key -from .parser import Parser from .class_map import load_class_map -from .constants import IMG_EXTENSIONS +from .img_extensions import get_img_extensions +from .parser import Parser + + +def find_images_and_targets( + folder: str, + types: Optional[Union[List, Tuple, Set]] = None, + class_to_idx: Optional[Dict] = None, + leaf_name_only: bool = True, + sort: bool = True +): + """ Walk folder recursively to discover images and map them to classes by folder names. + Args: + folder: root of folder to recursively search + types: types (file extensions) to search for in path + class_to_idx: specify mapping for class (folder name) to class index if set + leaf_name_only: use only leaf-name of folder walk for class names + sort: re-sort found images by name (for consistent ordering) -def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): + Returns: + A list of image and target tuples, class_to_idx mapping + """ + types = get_img_extensions(as_set=True) if not types else set(types) labels = [] filenames = [] for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): @@ -51,7 +71,8 @@ class ParserImageFolder(Parser): self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) if len(self.samples) == 0: raise RuntimeError( - f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') + f'Found 0 images in subfolders of {root}. ' + f'Supported image extensions are {", ".join(get_img_extensions())}') def __getitem__(self, index): path, target = self.samples[index] diff --git a/timm/data/parsers/parser_image_in_tar.py b/timm/data/parsers/parser_image_in_tar.py index c6ada962..4fcad797 100644 --- a/timm/data/parsers/parser_image_in_tar.py +++ b/timm/data/parsers/parser_image_in_tar.py @@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure.
Hacked together by / Copyright 2020 Ross Wightman """ +import logging import os -import tarfile import pickle -import logging -import numpy as np +import tarfile from glob import glob -from typing import List, Dict +from typing import List, Tuple, Dict, Set, Optional, Union + +import numpy as np from timm.utils.misc import natural_key -from .parser import Parser from .class_map import load_class_map -from .constants import IMG_EXTENSIONS - +from .img_extensions import get_img_extensions +from .parser import Parser _logger = logging.getLogger(__name__) CACHE_FILENAME_SUFFIX = '_tarinfos.pickle' @@ -39,7 +39,7 @@ class TarState: self.tf = None -def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS): +def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]): sample_count = 0 for i, ti in enumerate(tf): if not ti.isfile(): @@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE return sample_count -def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True): +def extract_tarinfos( + root, + class_name_to_idx: Optional[Dict] = None, + cache_tarinfo: Optional[bool] = None, + extensions: Optional[Union[List, Tuple, Set]] = None, + sort: bool = True +): + extensions = get_img_extensions(as_set=True) if not extensions else set(extensions) root_is_tar = False if os.path.isfile(root): assert os.path.splitext(root)[-1].lower() == '.tar' @@ -176,8 +183,8 @@ class ParserImageInTar(Parser): self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos( self.root, class_name_to_idx=class_name_to_idx, - cache_tarinfo=cache_tarinfo, - extensions=IMG_EXTENSIONS) + cache_tarinfo=cache_tarinfo + ) self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()} if len(tarfiles) == 1 and tarfiles[0][0] is None: self.root_is_tar = True diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/parsers/parser_image_tar.py index 467537f4..c2ed429d 100644 --- a/timm/data/parsers/parser_image_tar.py +++ b/timm/data/parsers/parser_image_tar.py @@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman import os import tarfile -from .parser import Parser -from .class_map import load_class_map -from .constants import IMG_EXTENSIONS from timm.utils.misc import natural_key +from .class_map import load_class_map +from .img_extensions import get_img_extensions +from .parser import Parser + def extract_tarinfo(tarfile, class_to_idx=None, sort=True): + extensions = get_img_extensions(as_set=True) files = [] labels = [] for ti in tarfile.getmembers(): @@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True): dirname, basename = os.path.split(ti.path) label = os.path.basename(dirname) ext = os.path.splitext(basename)[1] - if ext.lower() in IMG_EXTENSIONS: + if ext.lower() in extensions: files.append(ti) labels.append(label) if class_to_idx is None: From 06307b8b41da5783f38167f8ab609f83fb6b351d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 14:37:58 -0700 Subject: [PATCH 15/26] Remove experimental downsample in block support in ConvNeXt. Experiment further before keeping it in. 
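With the in-block stride path removed, stage transitions go back to the standard norm + strided-conv downsample between stages, and blocks themselves stay stride 1. A standalone sketch of that retained pattern (GroupNorm(1, C) stands in here for timm's channels-first LayerNorm2d):

    import torch
    import torch.nn as nn

    def make_downsample(in_chs, out_chs, stride, conv_bias=True):
        # identity unless channels change or stride > 1
        if in_chs != out_chs or stride > 1:
            return nn.Sequential(
                nn.GroupNorm(1, in_chs),  # stand-in for LayerNorm2d
                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
            )
        return nn.Identity()

    ds = make_downsample(96, 192, stride=2)
    assert ds(torch.randn(1, 96, 56, 56)).shape == (1, 192, 28, 28)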
--- timm/models/convnext.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 138e5030..be0c9a66 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -17,7 +17,6 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_notrace_module from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d from .registry import register_model @@ -124,7 +123,6 @@ class ConvNeXtBlock(nn.Module): norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp - self.shortcut_after_dw = stride > 1 self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias) self.norm = norm_layer(dim_out) @@ -135,9 +133,6 @@ class ConvNeXtBlock(nn.Module): def forward(self, x): shortcut = x x = self.conv_dw(x) - if self.shortcut_after_dw: - shortcut = x - if self.use_conv_mlp: x = self.norm(x) x = self.mlp(x) @@ -150,7 +145,6 @@ class ConvNeXtBlock(nn.Module): x = x.mul(self.gamma.reshape(1, -1, 1, 1)) x = self.drop_path(x) + shortcut - #print('b', x.shape) return x @@ -164,7 +158,6 @@ class ConvNeXtStage(nn.Module): depth=2, drop_path_rates=None, ls_init_value=1.0, - downsample_block=False, conv_mlp=False, conv_bias=True, norm_layer=None, @@ -173,14 +166,14 @@ class ConvNeXtStage(nn.Module): super().__init__() self.grad_checkpointing = False - if downsample_block or (in_chs == out_chs and stride == 1): - self.downsample = nn.Identity() - else: + if in_chs != out_chs or stride > 1: self.downsample = nn.Sequential( norm_layer(in_chs), nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias), ) in_chs = out_chs + else: + self.downsample = nn.Identity() drop_path_rates = drop_path_rates or [0.] 
* depth stage_blocks = [] @@ -188,7 +181,6 @@ class ConvNeXtStage(nn.Module): stage_blocks.append(ConvNeXtBlock( dim=in_chs, dim_out=out_chs, - stride=stride if downsample_block and i == 0 else 1, drop_path=drop_path_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, @@ -236,7 +228,6 @@ class ConvNeXt(nn.Module): stem_stride=4, head_init_scale=1., head_norm_first=False, - downsample_block=False, conv_mlp=False, conv_bias=True, norm_layer=None, @@ -291,7 +282,6 @@ class ConvNeXt(nn.Module): depth=depths[i], drop_path_rates=dp_rates[i], ls_init_value=ls_init_value, - downsample_block=downsample_block, conv_mlp=conv_mlp, conv_bias=conv_bias, norm_layer=norm_layer, @@ -418,7 +408,7 @@ def convnext_nano_hnf(pretrained=False, **kwargs): @register_model def convnext_nano_ols(pretrained=False, **kwargs): model_args = dict( - depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), downsample_block=True, + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs) model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args) return model @@ -426,7 +416,8 @@ def convnext_nano_ols(pretrained=False, **kwargs): @register_model def convnext_tiny_hnf(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) + model_args = dict( + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) return model From eca09b86423d8f441e55f27205efa1b3c9e77d41 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 14:41:01 -0700 Subject: [PATCH 16/26] Add MobileVitV2 support. Fix #1332. Move GroupNorm1 to common layers (used in poolformer + mobilevitv2). Keep old custom ConvNeXt LayerNorm2d impl as LayerNormExp2d for reference. --- timm/models/layers/__init__.py | 2 +- timm/models/layers/create_attn.py | 2 +- timm/models/layers/norm.py | 54 +++- timm/models/mobilevit.py | 443 +++++++++++++++++++++++++++++- timm/models/poolformer.py | 11 +- 5 files changed, 489 insertions(+), 23 deletions(-) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index b1f452ff..b9eeec0f 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -25,7 +25,7 @@ from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, LayerNorm2d +from .norm import GroupNorm, GroupNorm1, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index 028c0f75..cc7e91ea 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -22,7 +22,7 @@ def get_attn(attn_type): if isinstance(attn_type, torch.nn.Module): return attn_type module_cls = None - if attn_type is not None: + if attn_type: if isinstance(attn_type, str): attn_type = attn_type.lower() # Lightweight attention modules (channel and/or coarse spatial).
diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 345f67bc..1677dbfa 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -14,11 +14,59 @@ class GroupNorm(nn.GroupNorm): return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) +class GroupNorm1(nn.GroupNorm): + """ Group Normalization with 1 group. + Input: tensor in shape [B, C, *] + """ + + def __init__(self, num_channels, **kwargs): + super().__init__(1, num_channels, **kwargs) + + class LayerNorm2d(nn.LayerNorm): - """ LayerNorm for channels of '2D' spatial BCHW tensors """ - def __init__(self, num_channels, eps=1e-6): - super().__init__(num_channels, eps=eps) + """ LayerNorm for channels of '2D' spatial NCHW tensors """ + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +def _is_contiguous(tensor: torch.Tensor) -> bool: + # jit is oh so lovely :/ + # if torch.jit.is_tracing(): + # return True + if torch.jit.is_scripting(): + return tensor.is_contiguous() + else: + return tensor.is_contiguous(memory_format=torch.contiguous_format) + + +@torch.jit.script +def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): + s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) + x = (x - u) * torch.rsqrt(s + eps) + x = x * weight[:, None, None] + bias[:, None, None] + return x + + +class LayerNormExp2d(nn.LayerNorm): + """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). + + Experimental implementation w/ manual norm for non-contiguous tensors. + + This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last + layout. However, benefits are not always clear and can perform worse on other GPUs. + """ + + def __init__(self, num_channels, eps=1e-6): + super().__init__(num_channels, eps=eps) + + def forward(self, x) -> torch.Tensor: + if _is_contiguous(x): + x = F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + else: + x = _layer_norm_cf(x, self.weight, self.bias, self.eps) + return x diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 1c55bd1c..2a3ab924 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -1,7 +1,8 @@ """ MobileViT Paper: -`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 +V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 +V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680 MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below) License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source) @@ -13,7 +14,7 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
# import math -from typing import Union, Callable, Dict, Tuple, Optional +from typing import Union, Callable, Dict, Tuple, Optional, Sequence import torch from torch import nn @@ -21,7 +22,7 @@ import torch.nn.functional as F from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups from .fx_features import register_notrace_module -from .layers import to_2tuple, make_divisible +from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath from .vision_transformer import Block as TransformerBlock from .helpers import build_model_with_cfg from .registry import register_model @@ -48,6 +49,48 @@ default_cfgs = { 'mobilevit_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'), 'semobilevit_s': _cfg(), + + 'mobilevitv2_050': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth', + crop_pct=0.888), + 'mobilevitv2_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth', + crop_pct=0.888), + 'mobilevitv2_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth', + crop_pct=0.888), + 'mobilevitv2_125': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth', + crop_pct=0.888), + 'mobilevitv2_150': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth', + crop_pct=0.888), + 'mobilevitv2_175': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth', + crop_pct=0.888), + 'mobilevitv2_200': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth', + crop_pct=0.888), + + 'mobilevitv2_150_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth', + crop_pct=0.888), + 'mobilevitv2_175_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth', + crop_pct=0.888), + 'mobilevitv2_200_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth', + crop_pct=0.888), + + 'mobilevitv2_150_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'mobilevitv2_175_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'mobilevitv2_200_384_in22ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), } @@ -72,6 +115,40 @@ def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4, ) +def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5): + # inverted residual + 
mobilevit blocks as per MobileViT network + return ( + _inverted_residual_block(d=d, c=c, s=s, br=br), + ByoBlockCfg( + type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1, + block_kwargs=dict( + transformer_depth=transformer_depth, + patch_size=patch_size) + ) + ) + + +def _mobilevitv2_cfg(multiplier=1.0): + chs = (64, 128, 256, 384, 512) + if multiplier != 1.0: + chs = tuple([int(c * multiplier) for c in chs]) + cfg = ByoModelCfg( + blocks=( + _inverted_residual_block(d=1, c=chs[0], s=1, br=2.0), + _inverted_residual_block(d=2, c=chs[1], s=2, br=2.0), + _mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2), + _mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4), + _mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3), + ), + stem_chs=int(32 * multiplier), + stem_type='3x3', + stem_pool='', + downsample='', + act_layer='silu', + ) + return cfg + + model_cfgs = dict( mobilevit_xxs=ByoModelCfg( blocks=( @@ -137,11 +214,19 @@ model_cfgs = dict( attn_kwargs=dict(rd_ratio=1/8), num_features=640, ), + + mobilevitv2_050=_mobilevitv2_cfg(.50), + mobilevitv2_075=_mobilevitv2_cfg(.75), + mobilevitv2_125=_mobilevitv2_cfg(1.25), + mobilevitv2_100=_mobilevitv2_cfg(1.0), + mobilevitv2_150=_mobilevitv2_cfg(1.5), + mobilevitv2_175=_mobilevitv2_cfg(1.75), + mobilevitv2_200=_mobilevitv2_cfg(2.0), ) @register_notrace_module -class MobileViTBlock(nn.Module): +class MobileVitBlock(nn.Module): """ MobileViT block Paper: https://arxiv.org/abs/2110.02178?context=cs.LG """ @@ -165,9 +250,9 @@ class MobileViTBlock(nn.Module): drop_path_rate: float = 0., layers: LayerFn = None, transformer_norm_layer: Callable = nn.LayerNorm, - downsample: str = '' + **kwargs, # eat unused args ): - super(MobileViTBlock, self).__init__() + super(MobileVitBlock, self).__init__() layers = layers or LayerFn() groups = num_groups(group_size, in_chs) @@ -241,7 +326,270 @@ class MobileViTBlock(nn.Module): return x -register_block('mobilevit', MobileViTBlock) +class LinearSelfAttention(nn.Module): + """ + This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680` + This layer can be used for self- as well as cross-attention. + Args: + embed_dim (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + attn_drop (float): Dropout value for context scores. Default: 0.0 + bias (bool): Use bias in learnable layers. Default: True + Shape: + - Input: :math:`(N, C, P, N)` where :math:`N` is the batch size, :math:`C` is the input channels, + :math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches + - Output: same as the input + .. note:: + For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels + in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor, + we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be + expensive on resource-constrained devices) that may be required to convert the unfolded tensor from + channel-first to channel-last format in case of a linear layer. 
+ """ + + def __init__( + self, + embed_dim: int, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + + self.qkv_proj = nn.Conv2d( + in_channels=embed_dim, + out_channels=1 + (2 * embed_dim), + bias=bias, + kernel_size=1, + ) + self.attn_drop = nn.Dropout(attn_drop) + self.out_proj = nn.Conv2d( + in_channels=embed_dim, + out_channels=embed_dim, + bias=bias, + kernel_size=1, + ) + self.out_drop = nn.Dropout(proj_drop) + + def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor: + # [B, C, P, N] --> [B, h + 2d, P, N] + qkv = self.qkv_proj(x) + + # Project x into query, key and value + # Query --> [B, 1, P, N] + # value, key --> [B, d, P, N] + query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1) + + # apply softmax along N dimension + context_scores = F.softmax(query, dim=-1) + context_scores = self.attn_drop(context_scores) + + # Compute context vector + # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1] + context_vector = (key * context_scores).sum(dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + out = self.out_drop(out) + return out + + @torch.jit.ignore() + def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + # x --> [B, C, P, N] + # x_prev = [B, C, P, M] + batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape + q_patch_area, q_num_patches = x.shape[-2:] + + assert ( + kv_patch_area == q_patch_area + ), "The number of pixels in a patch for query and key_value should be the same" + + # compute query, key, and value + # [B, C, P, M] --> [B, 1 + d, P, M] + qk = F.conv2d( + x_prev, + weight=self.qkv_proj.weight[:self.embed_dim + 1], + bias=self.qkv_proj.bias[:self.embed_dim + 1], + ) + + # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M] + query, key = qk.split([1, self.embed_dim], dim=1) + # [B, C, P, N] --> [B, d, P, N] + value = F.conv2d( + x, + weight=self.qkv_proj.weight[self.embed_dim + 1], + bias=self.qkv_proj.bias[self.embed_dim + 1] if self.qkv_proj.bias is not None else None, + ) + + # apply softmax along M dimension + context_scores = F.softmax(query, dim=-1) + context_scores = self.attn_drop(context_scores) + + # compute context vector + # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1] + context_vector = (key * context_scores).sum(dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + out = self.out_drop(out) + return out + + def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + if x_prev is None: + return self._forward_self_attn(x) + else: + return self._forward_cross_attn(x, x_prev=x_prev) + + +class LinearTransformerBlock(nn.Module): + """ + This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 paper <>`_ + Args: + embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)` + mlp_ratio (float): Inner dimension ratio of the FFN relative to embed_dim + drop (float): Dropout rate. Default: 0.0 + attn_drop (float): Dropout rate for attention in multi-head attention. Default: 0.0 + drop_path (float): Stochastic depth rate Default: 0.0 + norm_layer (Callable): Normalization layer. 
Default: layer_norm_2d + Shape: + - Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim, + :math:`P` is number of pixels in a patch, and :math:`N` is number of patches, + - Output: same shape as the input + """ + + def __init__( + self, + embed_dim: int, + mlp_ratio: float = 2.0, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer=None, + norm_layer=None, + ) -> None: + super().__init__() + act_layer = act_layer or nn.SiLU + norm_layer = norm_layer or GroupNorm1 + + self.norm1 = norm_layer(embed_dim) + self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop) + self.drop_path1 = DropPath(drop_path) + + self.norm2 = norm_layer(embed_dim) + self.mlp = ConvMlp( + in_features=embed_dim, + hidden_features=int(embed_dim * mlp_ratio), + act_layer=act_layer, + drop=drop) + self.drop_path2 = DropPath(drop_path) + + def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: + if x_prev is None: + # self-attention + x = x + self.drop_path1(self.attn(self.norm1(x))) + else: + # cross-attention + res = x + x = self.norm1(x) # norm + x = self.attn(x, x_prev) # attn + x = self.drop_path1(x) + res # residual + + # Feed forward network + x = x + self.drop_path2(self.mlp(self.norm2(x))) + return x + + +@register_notrace_module +class MobileVitV2Block(nn.Module): + """ + This class defines the `MobileViTv2 block <>`_ + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 3, + bottle_ratio: float = 1.0, + group_size: Optional[int] = 1, + dilation: Tuple[int, int] = (1, 1), + mlp_ratio: float = 2.0, + transformer_dim: Optional[int] = None, + transformer_depth: int = 2, + patch_size: int = 8, + attn_drop: float = 0., + drop: int = 0., + drop_path_rate: float = 0., + layers: LayerFn = None, + transformer_norm_layer: Callable = GroupNorm1, + **kwargs, # eat unused args + ): + super(MobileVitV2Block, self).__init__() + layers = layers or LayerFn() + groups = num_groups(group_size, in_chs) + out_chs = out_chs or in_chs + transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs) + + self.conv_kxk = layers.conv_norm_act( + in_chs, in_chs, kernel_size=kernel_size, + stride=1, groups=groups, dilation=dilation[0]) + self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False) + + self.transformer = nn.Sequential(*[ + LinearTransformerBlock( + transformer_dim, + mlp_ratio=mlp_ratio, + attn_drop=attn_drop, + drop=drop, + drop_path=drop_path_rate, + act_layer=layers.act, + norm_layer=transformer_norm_layer + ) + for _ in range(transformer_depth) + ]) + self.norm = transformer_norm_layer(transformer_dim) + + self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False) + + self.patch_size = to_2tuple(patch_size) + self.patch_area = self.patch_size[0] * self.patch_size[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + patch_h, patch_w = self.patch_size + new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w + num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w + num_patches = num_patch_h * num_patch_w # N + if new_h != H or new_w != W: + x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True) + + # Local representation + x = self.conv_kxk(x) + x = self.conv_1x1(x) + + # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N] + C = x.shape[1] + x = 
x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4) + x = x.reshape(B, C, -1, num_patches) + + # Global representations + x = self.transformer(x) + x = self.norm(x) + + # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W] + x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3) + x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w) + + x = self.conv_proj(x) + return x + + +register_block('mobilevit', MobileVitBlock) +register_block('mobilevit2', MobileVitV2Block) def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): @@ -252,6 +600,14 @@ def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): **kwargs) +def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + ByobNet, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + @register_model def mobilevit_xxs(pretrained=False, **kwargs): return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs) @@ -269,4 +625,75 @@ def mobilevit_s(pretrained=False, **kwargs): @register_model def semobilevit_s(pretrained=False, **kwargs): - return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) \ No newline at end of file + return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_050(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_075(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_100(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_125(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200(pretrained=False, **kwargs): + return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, 
**kwargs) + + +@register_model +def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs): + return _create_mobilevit( + 'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs) \ No newline at end of file diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 17d657b0..a95195b4 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -26,7 +26,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, checkpoint_seq -from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp +from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1 from .registry import register_model @@ -80,15 +80,6 @@ class PatchEmbed(nn.Module): return x -class GroupNorm1(nn.GroupNorm): - """ Group Normalization with 1 group. - Input: tensor in shape [B, C, H, W] - """ - - def __init__(self, num_channels, **kwargs): - super().__init__(1, num_channels, **kwargs) - - class Pooling(nn.Module): def __init__(self, pool_size=3): super().__init__() From db0cee991028e772c0131e809da1e9e5ea60c568 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 14:43:27 -0700 Subject: [PATCH 17/26] Refactor cspnet configuration using dataclasses, update feature extraction for new cs3 variants. --- timm/models/cspnet.py | 710 ++++++++++++++++++++++++++---------------- 1 file changed, 448 insertions(+), 262 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 4591f101..f0a26baf 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,7 +12,10 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ +import collections.abc +from dataclasses import dataclass, field, asdict from functools import partial +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,7 +23,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP -from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer +from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, create_act_layer, make_divisible from .registry import register_model @@ -58,218 +61,278 @@ default_cfgs = { ), 'darknetaa53': _cfg(url=''), + 'cs3darknet_s': _cfg( + url=''), 'cs3darknet_m': _cfg( url=''), 'cs3darknet_l': _cfg( url=''), + 'cs3darknet_x': _cfg( + url=''), + + 'cs3darknet_focus_s': _cfg( + url=''), 'cs3darknet_focus_m': _cfg( url=''), 'cs3darknet_focus_l': _cfg( url=''), + 'cs3darknet_focus_x': _cfg( + url=''), + + 'cs3sedarknet_xdw': _cfg( + url=''), } +@dataclass +class CspStemCfg: + out_chs: Union[int, Tuple[int, ...]] = 32 + stride: Union[int, Tuple[int, ...]] = 2 + kernel_size: int = 3 + padding: Union[int, str] = '' + pool: Optional[str] = '' + + +def _pad_arg(x, n): + # pads an argument tuple to specified n by padding with last value + if not isinstance(x, (tuple, list)): + x = (x,) + curr_n = len(x) + pad_n = n - curr_n + if pad_n <= 0: + return x[:n] + return tuple(x + (x[-1],) * pad_n) + + +@dataclass +class CspStagesCfg: + depth: Tuple[int, ...] = (3, 3, 5, 2) # block depth (number of block repeats in stages) + out_chs: Tuple[int, ...] 
= (128, 256, 512, 1024) # number of output channels for blocks in stage + stride: Union[int, Tuple[int, ...]] = 2 # stride of stage + groups: Union[int, Tuple[int, ...]] = 1 # num kxk conv groups + block_ratio: Union[float, Tuple[float, ...]] = 1.0 + bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage + avg_down: Union[bool, Tuple[bool, ...]] = False + attn_layer: Optional[Union[str, Tuple[str, ...]]] = None + stage_type: Union[str, Tuple[str]] = 'csp' # stage type ('csp', 'cs2', 'dark') + block_type: Union[str, Tuple[str]] = 'bottle' # blocks type for stages ('bottle', 'dark') + + # cross-stage only + expand_ratio: Union[float, Tuple[float, ...]] = 1.0 + cross_linear: Union[bool, Tuple[bool, ...]] = False + down_growth: Union[bool, Tuple[bool, ...]] = False + + def __post_init__(self): + n = len(self.depth) + assert len(self.out_chs) == n + self.stride = _pad_arg(self.stride, n) + self.groups = _pad_arg(self.groups, n) + self.block_ratio = _pad_arg(self.block_ratio, n) + self.bottle_ratio = _pad_arg(self.bottle_ratio, n) + self.avg_down = _pad_arg(self.avg_down, n) + self.attn_layer = _pad_arg(self.attn_layer, n) + self.stage_type = _pad_arg(self.stage_type, n) + self.block_type = _pad_arg(self.block_type, n) + + self.expand_ratio = _pad_arg(self.expand_ratio, n) + self.cross_linear = _pad_arg(self.cross_linear, n) + self.down_growth = _pad_arg(self.down_growth, n) + + +@dataclass +class CspModelCfg: + stem: CspStemCfg + stages: CspStagesCfg + zero_init_last: bool = True # zero init last weight (usually bn) in residual path + act_layer: str = 'relu' + norm_layer: str = 'batchnorm' + aa_layer: Optional[str] = None # FIXME support string factory for this + + +def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False, act_layer='silu', focus=False): + if focus: + stem_cfg = CspStemCfg( + out_chs=make_divisible(64 * width_multiplier), + kernel_size=6, stride=2, padding=2, pool='') + else: + stem_cfg = CspStemCfg( + out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]), + kernel_size=3, stride=2, pool='') + return CspModelCfg( + stem=stem_cfg, + stages=CspStagesCfg( + out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]), + depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]), + stride=2, + bottle_ratio=1., + block_ratio=0.5, + avg_down=avg_down, + stage_type='cs3', + block_type='dark', + ), + act_layer=act_layer, + ) + + model_cfgs = dict( - cspresnet50=dict( - stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), - stage=dict( - out_chs=(128, 256, 512, 1024), + cspresnet50=CspModelCfg( + stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(2.,) * 4, - bottle_ratio=(0.5,) * 4, - block_ratio=(1.,) * 4, + out_chs=(128, 256, 512, 1024), + stride=(1, 2), + expand_ratio=2., + bottle_ratio=0.5, cross_linear=True, - ) + ), ), - cspresnet50d=dict( - stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), - stage=dict( - out_chs=(128, 256, 512, 1024), + cspresnet50d=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(2.,) * 4, - bottle_ratio=(0.5,) * 4, - block_ratio=(1.,) * 4, + out_chs=(128, 256, 512, 1024), + stride=(1,) + (2,), + expand_ratio=2., + bottle_ratio=0.5, + block_ratio=1., cross_linear=True, ) ), - cspresnet50w=dict( - 
stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), - stage=dict( - out_chs=(256, 512, 1024, 2048), + cspresnet50w=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - exp_ratio=(1.,) * 4, - bottle_ratio=(0.25,) * 4, - block_ratio=(0.5,) * 4, + out_chs=(256, 512, 1024, 2048), + stride=(1,) + (2,), + expand_ratio=1., + bottle_ratio=0.25, + block_ratio=0.5, cross_linear=True, ) ), - cspresnext50=dict( - stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), - stage=dict( - out_chs=(256, 512, 1024, 2048), + cspresnext50=CspModelCfg( + stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'), + stages=CspStagesCfg( depth=(3, 3, 5, 2), - stride=(1,) + (2,) * 3, - groups=(32,) * 4, - exp_ratio=(1.,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, + out_chs=(256, 512, 1024, 2048), + stride=(1,) + (2,), + groups=32, + expand_ratio=1., + bottle_ratio=1., + block_ratio=0.5, cross_linear=True, ) ), - cspdarknet53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + cspdarknet53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - exp_ratio=(2.,) + (1.,) * 4, - bottle_ratio=(0.5,) + (1.0,) * 4, - block_ratio=(1.,) + (0.5,) * 4, + out_chs=(64, 128, 256, 512, 1024), + stride=2, + expand_ratio=(2.,) + (1.,), + bottle_ratio=(0.5,) + (1.,), + block_ratio=(1.,) + (0.5,), down_growth=True, - ) + block_type='dark', + ), + act_layer='leaky_relu', ), - darknet17=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + darknet17=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1,) * 5, - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) - ), - darknet21=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( out_chs=(64, 128, 256, 512, 1024), - depth=(1, 1, 1, 2, 2), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) + stride=(2,), + bottle_ratio=(0.5,), + block_ratio=(1.,), + stage_type='dark', + block_type='dark', + ), + act_layer='leaky_relu', ), - sedarknet21=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), + darknet21=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( depth=(1, 1, 1, 2, 2), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - attn_layer=('se',) * 5, - ) - ), - darknet53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( out_chs=(64, 128, 256, 512, 1024), - depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - ) - ), + stride=(2,), + bottle_ratio=(0.5,), + block_ratio=(1.,), + stage_type='dark', + block_type='dark', - darknetaa53=dict( - stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), - stage=dict( - out_chs=(64, 128, 256, 512, 1024), - depth=(1, 2, 8, 8, 4), - stride=(2,) * 5, - bottle_ratio=(0.5,) * 5, - block_ratio=(1.,) * 5, - avg_down=True, ), + act_layer='leaky_relu', ), + sedarknet21=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 1, 1, 2, 2), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + 
block_ratio=1., + attn_layer='se', + stage_type='dark', + block_type='dark', - cs3darknet_m=dict( - stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''), - stage=dict( - out_chs=(96, 192, 384, 768), - depth=(2, 4, 6, 2), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, ), + act_layer='leaky_relu', ), - cs3darknet_l=dict( - stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), - stage=dict( - out_chs=(128, 256, 512, 1024), - depth=(3, 6, 9, 3), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, + darknet53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 2, 8, 8, 4), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + stage_type='dark', + block_type='dark', ), + act_layer='leaky_relu', ), - - cs3darknet_focus_m=dict( - stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''), - stage=dict( - out_chs=(96, 192, 384, 768), - depth=(2, 4, 6, 2), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, + darknetaa53=CspModelCfg( + stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''), + stages=CspStagesCfg( + depth=(1, 2, 8, 8, 4), + out_chs=(64, 128, 256, 512, 1024), + stride=2, + bottle_ratio=0.5, + block_ratio=1., + avg_down=True, + stage_type='dark', + block_type='dark', ), + act_layer='leaky_relu', ), - cs3darknet_focus_l=dict( - stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''), - stage=dict( - out_chs=(128, 256, 512, 1024), - depth=(3, 6, 9, 3), - stride=(2,) * 4, - bottle_ratio=(1.,) * 4, - block_ratio=(0.5,) * 4, - avg_down=False, - ), - ) -) + cs3darknet_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5), + cs3darknet_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67), + cs3darknet_l=_cs3darknet_cfg(), + cs3darknet_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33), -def create_stem( - in_chans=3, - out_chs=32, - kernel_size=3, - stride=2, - pool='', - padding='', - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - aa_layer=None -): - stem = nn.Sequential() - if not isinstance(out_chs, (tuple, list)): - out_chs = [out_chs] - assert len(out_chs) - in_c = in_chans - for i, out_c in enumerate(out_chs): - conv_name = f'conv{i + 1}' - stem.add_module(conv_name, ConvNormAct( - in_c, out_c, kernel_size, - stride=stride if i == 0 else 1, - padding=padding if i == 0 else '', - act_layer=act_layer, - norm_layer=norm_layer - )) - in_c = out_c - last_conv = conv_name - if pool: - if aa_layer is not None: - stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) - stem.add_module('aa', aa_layer(channels=in_c, stride=2)) - else: - stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) - return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv])) + cs3darknet_focus_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True), + cs3darknet_focus_m=_cs3darknet_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True), + cs3darknet_focus_l=_cs3darknet_cfg(focus=True), + cs3darknet_focus_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True), + + cs3sedarknet_xdw=CspModelCfg( + stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''), + stages=CspStagesCfg( + depth=(3, 6, 12, 4), + out_chs=(256, 512, 1024, 2048), + stride=2, + groups=(1, 1, 256, 512), + bottle_ratio=0.5, + block_ratio=0.5, + 
attn_layer='se', + ), + ), +) -class ResBottleneck(nn.Module): +class BottleneckBlock(nn.Module): """ ResNe(X)t Bottleneck Block """ @@ -286,9 +349,9 @@ class ResBottleneck(nn.Module): attn_layer=None, aa_layer=None, drop_block=None, - drop_path=None + drop_path=0. ): - super(ResBottleneck, self).__init__() + super(BottleneckBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) @@ -299,8 +362,8 @@ class ResBottleneck(nn.Module): self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None - self.drop_path = drop_path - self.act3 = act_layer() + self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() + self.act3 = create_act_layer(act_layer) def zero_init_last(self): nn.init.zeros_(self.conv3.bn.weight) @@ -314,9 +377,7 @@ class ResBottleneck(nn.Module): x = self.conv3(x) if self.attn3 is not None: x = self.attn3(x) - if self.drop_path is not None: - x = self.drop_path(x) - x = x + shortcut + x = self.drop_path(x) + shortcut # FIXME partial shortcut needed if first block handled as per original, not used for my current impl #x[:, :shortcut.size(1)] += shortcut x = self.act3(x) @@ -339,7 +400,7 @@ class DarkBlock(nn.Module): attn_layer=None, aa_layer=None, drop_block=None, - drop_path=None + drop_path=0. ): super(DarkBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) @@ -349,7 +410,7 @@ class DarkBlock(nn.Module): mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, aa_layer=aa_layer, drop_layer=drop_block, **ckwargs) self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer) - self.drop_path = drop_path + self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() def zero_init_last(self): nn.init.zeros_(self.conv2.bn.weight) @@ -360,9 +421,7 @@ class DarkBlock(nn.Module): x = self.conv2(x) if self.attn is not None: x = self.attn(x) - if self.drop_path is not None: - x = self.drop_path(x) - x = x + shortcut + x = self.drop_path(x) + shortcut return x @@ -377,27 +436,27 @@ class CrossStage(nn.Module): depth, block_ratio=1., bottle_ratio=1., - exp_ratio=1., + expand_ratio=1., groups=1, first_dilation=None, avg_down=False, down_growth=False, cross_linear=False, block_dpr=None, - block_fn=ResBottleneck, + block_fn=BottleneckBlock, **block_kwargs ): super(CrossStage, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels - self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) + self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) block_out_chs = int(round(out_chs * block_ratio)) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: if avg_down: self.conv_down = nn.Sequential( - nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: @@ -417,9 +476,15 @@ class CrossStage(nn.Module): self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, 
block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs # transition convs @@ -429,7 +494,7 @@ class CrossStage(nn.Module): def forward(self, x): x = self.conv_down(x) x = self.conv_exp(x) - xs, xb = x.split(self.exp_chs // 2, dim=1) + xs, xb = x.split(self.expand_chs // 2, dim=1) xb = self.blocks(xb) xb = self.conv_transition_b(xb).contiguous() out = self.conv_transition(torch.cat([xs, xb], dim=1)) @@ -449,27 +514,27 @@ class CrossStage3(nn.Module): depth, block_ratio=1., bottle_ratio=1., - exp_ratio=1., + expand_ratio=1., groups=1, first_dilation=None, avg_down=False, down_growth=False, cross_linear=False, block_dpr=None, - block_fn=ResBottleneck, + block_fn=BottleneckBlock, **block_kwargs ): super(CrossStage3, self).__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels - self.exp_chs = exp_chs = int(round(out_chs * exp_ratio)) + self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) block_out_chs = int(round(out_chs * block_ratio)) conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: if avg_down: self.conv_down = nn.Sequential( - nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: @@ -487,9 +552,15 @@ class CrossStage3(nn.Module): self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs # transition convs @@ -498,7 +569,7 @@ class CrossStage3(nn.Module): def forward(self, x): x = self.conv_down(x) x = self.conv_exp(x) - x1, x2 = x.split(self.exp_chs // 2, dim=1) + x1, x2 = x.split(self.expand_chs // 2, dim=1) x1 = self.blocks(x1) out = self.conv_transition(torch.cat([x1, x2], dim=1)) return out @@ -519,7 +590,7 @@ class DarkStage(nn.Module): groups=1, first_dilation=None, avg_down=False, - block_fn=ResBottleneck, + block_fn=BottleneckBlock, block_dpr=None, **block_kwargs ): @@ -529,7 +600,7 @@ class DarkStage(nn.Module): if avg_down: self.conv_down = nn.Sequential( - nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling + nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: @@ -541,9 +612,15 @@ class DarkStage(nn.Module): block_out_chs = int(round(out_chs * block_ratio)) self.blocks = nn.Sequential() for i in range(depth): - drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None self.blocks.add_module(str(i), block_fn( - prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + in_chs=prev_chs, + out_chs=block_out_chs, + 
dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + drop_path=block_dpr[i] if block_dpr is not None else 0., + **block_kwargs + )) prev_chs = block_out_chs def forward(self, x): @@ -552,38 +629,131 @@ class DarkStage(nn.Module): return x -def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): - # get per stage args for stage and containing blocks, calculate strides to meet target output_stride - num_stages = len(cfg['depth']) - if 'groups' not in cfg: - cfg['groups'] = (1,) * num_stages - if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)): - cfg['down_growth'] = (cfg['down_growth'],) * num_stages - if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): - cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages - if 'avg_down' in cfg and not isinstance(cfg['avg_down'], (list, tuple)): - cfg['avg_down'] = (cfg['avg_down'],) * num_stages - cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ - [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] - stage_strides = [] - stage_dilations = [] - stage_first_dilations = [] +def create_csp_stem( + in_chans=3, + out_chs=32, + kernel_size=3, + stride=2, + pool='', + padding='', + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + aa_layer=None +): + stem = nn.Sequential() + feature_info = [] + if not isinstance(out_chs, (tuple, list)): + out_chs = [out_chs] + stem_depth = len(out_chs) + assert stem_depth + assert stride in (1, 2, 4) + prev_feat = None + prev_chs = in_chans + last_idx = stem_depth - 1 + stem_stride = 1 + for i, chs in enumerate(out_chs): + conv_name = f'conv{i + 1}' + conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1 + if conv_stride > 1 and prev_feat is not None: + feature_info.append(prev_feat) + stem.add_module(conv_name, ConvNormAct( + prev_chs, chs, kernel_size, + stride=conv_stride, + padding=padding if i == 0 else '', + act_layer=act_layer, + norm_layer=norm_layer + )) + stem_stride *= conv_stride + prev_chs = chs + prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name])) + if pool: + assert stride > 2 + if prev_feat is not None: + feature_info.append(prev_feat) + if aa_layer is not None: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) + stem.add_module('aa', aa_layer(channels=prev_chs, stride=2)) + pool_name = 'aa' + else: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + pool_name = 'pool' + stem_stride *= 2 + prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name])) + feature_info.append(prev_feat) + return stem, feature_info + + +def _get_stage_fn(stage_type: str, stage_args): + assert stage_type in ('dark', 'csp', 'cs3') + if stage_type == 'dark': + stage_args.pop('expand_ratio', None) + stage_args.pop('cross_linear', None) + stage_args.pop('down_growth', None) + stage_fn = DarkStage + elif stage_type == 'csp': + stage_fn = CrossStage + else: + stage_fn = CrossStage3 + return stage_fn, stage_args + + +def _get_block_fn(stage_type: str, stage_args): + assert stage_type in ('dark', 'bottle') + if stage_type == 'dark': + return DarkBlock, stage_args + else: + return BottleneckBlock, stage_args + + +def create_csp_stages( + cfg: CspModelCfg, + drop_path_rate: float, + output_stride: int, + stem_feat: Dict[str, Any] +): + cfg_dict = asdict(cfg.stages) + num_stages = len(cfg.stages.depth) + 
cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \ + [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)] + stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())] + block_kwargs = dict( + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + aa_layer=cfg.aa_layer + ) + dilation = 1 - for cfg_stride in cfg['stride']: - stage_first_dilations.append(dilation) - if curr_stride >= output_stride: - dilation *= cfg_stride + net_stride = stem_feat['reduction'] + prev_chs = stem_feat['num_chs'] + prev_feat = stem_feat + feature_info = [] + stages = [] + for stage_idx, stage_args in enumerate(stage_args): + stage_fn, stage_args = _get_stage_fn(stage_args.pop('stage_type'), stage_args) + block_fn, stage_args = _get_block_fn(stage_args.pop('block_type'), stage_args) + stride = stage_args.pop('stride') + if stride != 1 and prev_feat: + feature_info.append(prev_feat) + if net_stride >= output_stride and stride > 1: + dilation *= stride stride = 1 - else: - stride = cfg_stride - curr_stride *= stride - stage_strides.append(stride) - stage_dilations.append(dilation) - cfg['stride'] = stage_strides - cfg['dilation'] = stage_dilations - cfg['first_dilation'] = stage_first_dilations - stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())] - return stage_args + net_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + + stages += [stage_fn( + prev_chs, + **stage_args, + stride=stride, + first_dilation=first_dilation, + dilation=dilation, + block_fn=block_fn, + **block_kwargs, + )] + prev_chs = stage_args['out_chs'] + prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}') + + feature_info.append(prev_feat) + return nn.Sequential(*stages), feature_info class CspNet(nn.Module): @@ -598,43 +768,39 @@ class CspNet(nn.Module): def __init__( self, - cfg, + cfg: CspModelCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', - act_layer=nn.LeakyReLU, - norm_layer=nn.BatchNorm2d, - aa_layer=None, drop_rate=0., drop_path_rate=0., - zero_init_last=True, - stage_fn=CrossStage, - block_fn=ResBottleneck): + zero_init_last=True + ): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) - layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + layer_args = dict( + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + aa_layer=cfg.aa_layer + ) + self.feature_info = [] # Construct the stem - self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args) - self.feature_info = [stem_feat_info] - prev_chs = stem_feat_info['num_chs'] - curr_stride = stem_feat_info['reduction'] # reduction does not include pool - if cfg['stem']['pool']: - curr_stride *= 2 + self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args) + self.feature_info.extend(stem_feat_info[:-1]) # Construct the stages - per_stage_args = _cfg_to_stage_args( - cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate) - self.stages = nn.Sequential() - for i, sa in enumerate(per_stage_args): - self.stages.add_module( - str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn)) - prev_chs = sa['out_chs'] - curr_stride *= sa['stride'] - self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages, stage_feat_info = create_csp_stages( + cfg, + 
drop_path_rate=drop_path_rate, + output_stride=output_stride, + stem_feat=stem_feat_info[-1], + ) + prev_chs = stage_feat_info[-1]['num_chs'] + self.feature_info.extend(stage_feat_info) # Construct the head self.num_features = prev_chs @@ -729,54 +895,74 @@ def cspresnext50(pretrained=False, **kwargs): @register_model def cspdarknet53(pretrained=False, **kwargs): - return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) + return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs) @register_model def darknet17(pretrained=False, **kwargs): - return _create_cspnet('darknet17', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet17', pretrained=pretrained, **kwargs) @register_model def darknet21(pretrained=False, **kwargs): - return _create_cspnet('darknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet21', pretrained=pretrained, **kwargs) @register_model def sedarknet21(pretrained=False, **kwargs): - return _create_cspnet('sedarknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs) @register_model def darknet53(pretrained=False, **kwargs): - return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet53', pretrained=pretrained, **kwargs) @register_model def darknetaa53(pretrained=False, **kwargs): - return _create_cspnet( - 'darknetaa53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_s(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs) @register_model def cs3darknet_m(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) + return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs) @register_model def cs3darknet_l(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) + return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_x(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_s(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs) @register_model def cs3darknet_focus_m(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_focus_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) + return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs) @register_model def cs3darknet_focus_l(pretrained=False, **kwargs): - return _create_cspnet( - 'cs3darknet_focus_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage3, act_layer='silu', **kwargs) \ No newline at end of file + return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs) + + +@register_model +def cs3darknet_focus_x(pretrained=False, **kwargs): + return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs) + + +@register_model +def cs3sedarknet_xdw(pretrained=False, **kwargs): + 
return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs) From 28e01520434930b420aedfc979f0f2fcee513628 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 15:13:06 -0700 Subject: [PATCH 18/26] Add --no-retry flag to benchmark.py to skip batch_size decay and retry on error. Fix #1226. Update deepspeed profile usage for latest DS releases. Fix # 1333 --- benchmark.py | 53 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/benchmark.py b/benchmark.py index 1362eeab..74f09489 100755 --- a/benchmark.py +++ b/benchmark.py @@ -71,6 +71,8 @@ parser.add_argument('--bench', default='both', type=str, help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'") parser.add_argument('--detail', action='store_true', default=False, help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False') +parser.add_argument('--no-retry', action='store_true', default=False, + help='Do not decay batch size and retry on error.') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for validation results (summary)') parser.add_argument('--num-warm-iter', default=10, type=int, @@ -169,10 +171,9 @@ def resolve_precision(precision: str): def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False): - macs, _ = get_model_profile( + _, macs, _ = get_model_profile( model=model, - input_res=(batch_size,) + input_size, # input shape or input to the input_constructor - input_constructor=None, # if specified, a constructor taking input_res is used as input to the model + input_shape=(batch_size,) + input_size, # input shape/resolution print_profile=detailed, # prints the model graph with the measured profile attached to each module detailed=detailed, # print the detailed profile warm_up=10, # the number of warm-ups before measuring the time of each module @@ -197,8 +198,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False class BenchmarkRunner: def __init__( - self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32', - fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): + self, + model_name, + detail=False, + device='cuda', + torchscript=False, + aot_autograd=False, + precision='float32', + fuser='', + num_warm_iter=10, + num_bench_iter=50, + use_train_size=False, + **kwargs + ): self.model_name = model_name self.detail = detail self.device = device @@ -256,7 +268,13 @@ class BenchmarkRunner: class InferenceBenchmarkRunner(BenchmarkRunner): - def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): + def __init__( + self, + model_name, + device='cuda', + torchscript=False, + **kwargs + ): super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) self.model.eval() @@ -325,7 +343,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner): class TrainBenchmarkRunner(BenchmarkRunner): - def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): + def __init__( + self, + model_name, + device='cuda', + torchscript=False, + **kwargs + ): super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) self.model.train() @@ -492,7 +516,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16): return max(0, int(out_batch_size)) -def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): +def _try_run(model_name, bench_fn, 
bench_kwargs, initial_batch_size, no_batch_size_retry=False): batch_size = initial_batch_size results = dict() error_str = 'Unknown' @@ -507,8 +531,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): if 'channels_last' in error_str: _logger.error(f'{model_name} not supported in channels_last, skipping.') break - _logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.') + _logger.error(f'"{error_str}" while running benchmark.') + if no_batch_size_retry: + break batch_size = decay_batch_exp(batch_size) + _logger.warning(f'Reducing batch size to {batch_size} for retry.') results['error'] = error_str return results @@ -550,7 +577,13 @@ def benchmark(args): model_results = OrderedDict(model=model) for prefix, bench_fn in zip(prefixes, bench_fns): - run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs) + run_results = _try_run( + model, + bench_fn, + bench_kwargs=bench_kwargs, + initial_batch_size=batch_size, + no_batch_size_retry=args.no_retry, + ) if prefix and 'error' not in run_results: run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()} model_results.update(run_results) From 500c190860bb80da348dd719dd8f0b73e44f0854 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 15:15:25 -0700 Subject: [PATCH 19/26] Add --aot-autograd (functorch efficient mem fusion) support to validate.py --- validate.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/validate.py b/validate.py index 27b88299..708ac2e5 100755 --- a/validate.py +++ b/validate.py @@ -38,6 +38,12 @@ try: except AttributeError: pass +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -101,8 +107,11 @@ parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') -parser.add_argument('--torchscript', dest='torchscript', action='store_true', - help='convert model torchscript for inference') +scripting_group = parser.add_mutually_exclusive_group() +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', + help='torch.jit.script the full model') +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. 
One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', @@ -162,7 +171,10 @@ def validate(args): if args.torchscript: torch.jit.optimized_execution(True) - model = torch.jit.script(model) + model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size'])) + if args.aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + model = memory_efficient_fusion(model) model = model.cuda() if args.apex_amp: From 4670d375c6ac37457094fa5519079936328d1e67 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 15:21:29 -0700 Subject: [PATCH 20/26] Reorg benchmark.py import --- benchmark.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/benchmark.py b/benchmark.py index 74f09489..23047bb5 100755 --- a/benchmark.py +++ b/benchmark.py @@ -6,24 +6,23 @@ An inference and train step benchmark script for timm models. Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse -import os import csv import json -import time import logging -import torch -import torch.nn as nn -import torch.nn.parallel +import time from collections import OrderedDict from contextlib import suppress from functools import partial +import torch +import torch.nn as nn +import torch.nn.parallel + +from timm.data import resolve_data_config from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 -from timm.data import resolve_data_config from timm.utils import setup_default_logging, set_jit_fuser - has_apex = False try: from apex import amp From 9be0c847154b7a20cfd3f51fc0c366099c007f5c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 15:33:53 -0700 Subject: [PATCH 21/26] Change set -> dict w/ None keys for dataset split synonym search, so always consistent if more than 1 exists. 
Fix #1224 --- timm/data/dataset_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index 194a597e..d0ac30b1 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -26,8 +26,8 @@ _TORCH_BASIC_DS = dict( kmnist=KMNIST, fashion_mnist=FashionMNIST, ) -_TRAIN_SYNONYM = {'train', 'training'} -_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'} +_TRAIN_SYNONYM = dict(train=None, training=None) +_EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None) def _search_split(root, split): From 58621723bda1fe386e8eebd729e743e255f992eb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 17:43:38 -0700 Subject: [PATCH 22/26] Add CrossStage3 DarkNet (cs3) weights --- timm/models/cspnet.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index f0a26baf..e8e8910e 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -57,30 +57,35 @@ default_cfgs = { 'sedarknet21': _cfg(url=''), 'darknet53': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic' + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0, ), 'darknetaa53': _cfg(url=''), 'cs3darknet_s': _cfg( - url=''), + url='', interpolation='bicubic'), 'cs3darknet_m': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95, + ), 'cs3darknet_l': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95), 'cs3darknet_x': _cfg( url=''), 'cs3darknet_focus_s': _cfg( - url=''), + url='', interpolation='bicubic'), 'cs3darknet_focus_m': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95), 'cs3darknet_focus_l': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95), 'cs3darknet_focus_x': _cfg( - url=''), + url='', interpolation='bicubic'), 'cs3sedarknet_xdw': _cfg( - url=''), + url='', interpolation='bicubic'), } From ce65a7b29fa39c3f4d09b03b515b2af31ec9aea5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 21:33:25 -0700 Subject: [PATCH 23/26] Update vit_relpos w/ some additional weights, some cleanup to match recent vit updates, more MLP log coord experiments. 
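For reference, the 'mlp' relative-position mode in this file feeds a table of log-spaced relative coordinates through a small MLP to produce per-head attention biases. A minimal standalone sketch of the two pre-existing coordinate transforms, with the mode names and math taken from the diff below (the window size here is purely illustrative):

    import math
    import torch

    def rel_log_coords(win_size=(3, 3), mode='swin'):
        # all pairwise (dy, dx) offsets within the window: (2*Wh-1, 2*Ww-1, 2)
        rel_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
        rel_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
        table = torch.stack(torch.meshgrid([rel_h, rel_w])).permute(1, 2, 0).contiguous()
        if mode == 'swin':
            # swin-v2 style: normalize offsets to [-8, 8], then log2-compress to roughly [-1, 1]
            table[:, :, 0] /= (win_size[0] - 1)
            table[:, :, 1] /= (win_size[1] - 1)
            table *= 8
            return torch.sign(table) * torch.log2(1.0 + table.abs()) / math.log2(8)
        # 'cr' style: unnormalized natural-log coordinates
        return torch.sign(table) * torch.log(1.0 + table.abs())

    print(rel_log_coords((3, 3)).shape)  # torch.Size([5, 5, 2])

The new 'rw' mode experimented with below combines both ideas: window-normalized log2 coords scaled back into [-1, 1], paired with a tanh activation and a gain factor on the MLP output in RelPosMlp.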
--- timm/models/vision_transformer_relpos.py | 194 +++++++++++++++++------ 1 file changed, 145 insertions(+), 49 deletions(-) diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 0c9ac989..52b3ce45 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -8,6 +8,7 @@ import math import logging from functools import partial from collections import OrderedDict +from dataclasses import dataclass from typing import Optional, Tuple import torch @@ -16,7 +17,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, named_apply +from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple from .registry import register_model @@ -47,9 +48,16 @@ default_cfgs = { 'vit_relpos_base_patch16_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'), + 'vit_srelpos_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'), + 'vit_srelpos_medium_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'), + + 'vit_relpos_medium_patch16_cls_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'), 'vit_relpos_base_patch16_cls_224': _cfg( url=''), - 'vit_relpos_base_patch16_gapcls_224': _cfg( + 'vit_relpos_base_patch16_clsgap_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'), 'vit_relpos_small_patch16_rpn_224': _cfg(url=''), @@ -59,35 +67,43 @@ default_cfgs = { } -def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor: - # cut and paste w/ modifications from swin / beit codebase - # cls to token & token 2 cls & cls to cls +def gen_relative_position_index( + q_size: Tuple[int, int], + k_size: Tuple[int, int] = None, + class_token: bool = False) -> torch.Tensor: + # Adapted with significant modifications from Swin / BeiT codebases # get pair-wise relative position index for each token inside the window - window_area = win_size[0] * win_size[1] - coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww - relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += win_size[1] - 1 - relative_coords[:, :, 0] *= 2 * win_size[1] - 1 + q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww + if k_size is None: + k_coords = q_coords + k_size = q_size + else: + # different q vs k sizes is a WIP + k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1) + relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww + 
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2 + _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0) + if class_token: - num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3 - relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) - relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias + # NOTE not intended or tested with MLP log-coords + max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1])) + num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3 + relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0]) relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 relative_position_index[0, 0] = num_relative_distance - 1 - else: - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - return relative_position_index + + return relative_position_index.contiguous() def gen_relative_log_coords( win_size: Tuple[int, int], pretrained_win_size: Tuple[int, int] = (0, 0), - mode='swin' + mode='swin', ): - # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well + assert mode in ('swin', 'cr', 'rw') + # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) @@ -100,12 +116,22 @@ def gen_relative_log_coords( relative_coords_table[:, :, 0] /= (win_size[0] - 1) relative_coords_table[:, :, 1] /= (win_size[1] - 1) relative_coords_table *= 8 # normalize to -8, 8 - scale = math.log2(8) + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) / math.log2(8) else: - # FIXME we should support a form of normalization (to -1/1) for this mode? 
- scale = math.log2(math.e) - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - 1.0 + relative_coords_table.abs()) / scale + if mode == 'rw': + # cr w/ window size normalization -> [-1,1] log coords + relative_coords_table[:, :, 0] /= (win_size[0] - 1) + relative_coords_table[:, :, 1] /= (win_size[1] - 1) + relative_coords_table *= 8 # scale to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) + relative_coords_table /= math.log2(9) # -> [-1, 1] + else: + # mode == 'cr' + relative_coords_table = torch.sign(relative_coords_table) * torch.log( + 1.0 + relative_coords_table.abs()) + return relative_coords_table @@ -115,19 +141,29 @@ class RelPosMlp(nn.Module): window_size, num_heads=8, hidden_dim=128, - class_token=False, + prefix_tokens=0, mode='cr', pretrained_window_size=(0, 0) ): super().__init__() self.window_size = window_size self.window_area = self.window_size[0] * self.window_size[1] - self.class_token = 1 if class_token else 0 + self.prefix_tokens = prefix_tokens self.num_heads = num_heads self.bias_shape = (self.window_area,) * 2 + (num_heads,) - self.apply_sigmoid = mode == 'swin' + if mode == 'swin': + self.bias_act = nn.Sigmoid() + self.bias_gain = 16 + mlp_bias = (True, False) + elif mode == 'rw': + self.bias_act = nn.Tanh() + self.bias_gain = 4 + mlp_bias = True + else: + self.bias_act = nn.Identity() + self.bias_gain = None + mlp_bias = True - mlp_bias = (True, False) if mode == 'swin' else True self.mlp = Mlp( 2, # x, y hidden_features=hidden_dim, @@ -155,10 +191,11 @@ class RelPosMlp(nn.Module): self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.view(self.bias_shape) relative_position_bias = relative_position_bias.permute(2, 0, 1) - if self.apply_sigmoid: - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - if self.class_token: - relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0]) + relative_position_bias = self.bias_act(relative_position_bias) + if self.bias_gain is not None: + relative_position_bias = self.bias_gain * relative_position_bias + if self.prefix_tokens: + relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0]) return relative_position_bias.unsqueeze(0).contiguous() def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): @@ -167,18 +204,18 @@ class RelPosMlp(nn.Module): class RelPosBias(nn.Module): - def __init__(self, window_size, num_heads, class_token=False): + def __init__(self, window_size, num_heads, prefix_tokens=0): super().__init__() + assert prefix_tokens <= 1 self.window_size = window_size self.window_area = window_size[0] * window_size[1] - self.class_token = 1 if class_token else 0 - self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,) + self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,) - num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) self.register_buffer( "relative_position_index", - gen_relative_position_index(self.window_size, class_token=self.class_token), + gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0), persistent=False, ) @@ -306,11 +343,32 @@ class 
VisionTransformerRelPos(nn.Module): """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6, - class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', - embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='avg', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=1e-6, + class_token=False, + fc_norm=False, + rel_pos_type='mlp', + rel_pos_dim=None, + shared_rel_pos=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='skip', + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + block_fn=RelPosBlock + ): """ Args: img_size (int, tuple): input image size @@ -345,19 +403,22 @@ class VisionTransformerRelPos(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 if class_token else 0 + self.num_prefix_tokens = 1 if class_token else 0 self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) feat_size = self.patch_embed.grid_size - rel_pos_args = dict(window_size=feat_size, class_token=class_token) + rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens) if rel_pos_type.startswith('mlp'): if rel_pos_dim: rel_pos_args['hidden_dim'] = rel_pos_dim + # FIXME experimenting with different relpos log coord configs if 'swin' in rel_pos_type: rel_pos_args['mode'] = 'swin' + elif 'rw' in rel_pos_type: + rel_pos_args['mode'] = 'rw' rel_pos_cls = partial(RelPosMlp, **rel_pos_args) else: rel_pos_cls = partial(RelPosBias, **rel_pos_args) @@ -367,7 +428,7 @@ class VisionTransformerRelPos(nn.Module): # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both... 
rel_pos_cls = None
-        self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None
+        self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList([
@@ -434,7 +495,7 @@ class VisionTransformerRelPos(nn.Module):
 
     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool:
-            x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
         x = self.fc_norm(x)
         return x if pre_logits else self.head(x)
 
@@ -502,6 +563,41 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_srelpos_small_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16) w/ shared relative log-coord position, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=False,
+        rel_pos_dim=384, shared_rel_pos=True, **kwargs)
+    model = _create_vision_transformer_relpos('vit_srelpos_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_srelpos_medium_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Medium (ViT-M/16) w/ shared relative log-coord position, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
+        rel_pos_dim=512, shared_rel_pos=True, **kwargs)
+    model = _create_vision_transformer_relpos(
+        'vit_srelpos_medium_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_relpos_medium_patch16_cls_224(pretrained=False, **kwargs):
+    """ ViT-Medium (ViT-M/16) w/ relative log-coord position, class token present
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
+        rel_pos_dim=256, class_token=True, global_pool='token', **kwargs)
+    model = _create_vision_transformer_relpos(
+        'vit_relpos_medium_patch16_cls_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
@@ -514,14 +610,14 @@ def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
 
 
 @register_model
-def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs):
+def vit_relpos_base_patch16_clsgap_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
     NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
     Leaving here for comparisons w/ a future re-train as it performs quite well.
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs)
-    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_clsgap_224', pretrained=pretrained, **model_kwargs)
     return model
 

From 7c7ecd24923b19338ca083d56369193e153294f0 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 7 Jul 2022 22:01:24 -0700
Subject: [PATCH 24/26] Add --use-train-size flag to force use of train
 input_size (over test input size) for validation.
Default test-time pooling to use train input size (fixes issues).
---
 timm/models/layers/test_time_pool.py |  2 +-
 validate.py                          | 11 +++++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py
index 98c0bf53..5826d8c9 100644
--- a/timm/models/layers/test_time_pool.py
+++ b/timm/models/layers/test_time_pool.py
@@ -36,7 +36,7 @@ class TestTimePoolHead(nn.Module):
         return x.view(x.size(0), -1)
 
 
-def apply_test_time_pool(model, config, use_test_size=True):
+def apply_test_time_pool(model, config, use_test_size=False):
     test_time_pool = False
     if not hasattr(model, 'default_cfg') or not model.default_cfg:
         return model, False
diff --git a/validate.py b/validate.py
index 708ac2e5..7fa22b49 100755
--- a/validate.py
+++ b/validate.py
@@ -67,6 +67,8 @@ parser.add_argument('--img-size', default=None, type=int, metavar='N',
                     help='Input image dimension, uses model default if empty')
 parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N',
                     help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
+parser.add_argument('--use-train-size', action='store_true', default=False,
+                    help='force use of train input size, even when test size is specified in pretrained cfg')
 parser.add_argument('--crop-pct', default=None, type=float, metavar='N',
                     help='Input image center crop pct')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
@@ -164,10 +166,15 @@ def validate(args):
         param_count = sum([m.numel() for m in model.parameters()])
         _logger.info('Model %s created, param count: %d' % (args.model, param_count))
 
-    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
+    data_config = resolve_data_config(
+        vars(args),
+        model=model,
+        use_test_size=not args.use_train_size,
+        verbose=True
+    )
     test_time_pool = False
     if args.test_pool:
-        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
+        model, test_time_pool = apply_test_time_pool(model, data_config)
 
     if args.torchscript:
         torch.jit.optimized_execution(True)

From a1cb25066e26c8bf8fa410987b808899657c9e20 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 7 Jul 2022 22:02:57 -0700
Subject: [PATCH 25/26] Add edgenext_small_rw weights trained with swin-like
 recipe. Better than original 'small' but not the recent 'USI' distilled
 weights.
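
For reference, a minimal loading sketch (illustrative only, not part of the diff below; it assumes the release asset referenced in the pretrained cfg has been published):

    import timm
    import torch

    # pulls the new edgenext_small_rw weights via the pretrained cfg url below
    model = timm.create_model('edgenext_small_rw', pretrained=True).eval()

    # the cfg below specifies a 320x320 test-time input @ crop_pct 1.0
    x = torch.randn(1, 3, 320, 320)
    with torch.no_grad():
        logits = model(x)  # (1, 1000) for the default 1000-class head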
--- timm/models/edgenext.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 97971ba6..29316b9a 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -46,10 +46,13 @@ default_cfgs = dict( # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"), edgenext_small=_cfg( # USI weights url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", - crop_pct=0.95 + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, ), - edgenext_small_rw=_cfg(), + edgenext_small_rw=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth', + test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), ) From 1c5cb819f94834b73843e1f088a3b2b12f550680 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 22:05:56 -0700 Subject: [PATCH 26/26] bump version to 0.6.3 before merge --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index 3e8e43bd..7165c7fa 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.6.2.dev0' +__version__ = '0.6.3.dev0'
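
For the relative position changes earlier in this series (RelPosMlp 'cr' mode vs the new 'rw' mode), a minimal standalone sketch of the two log-spaced coordinate transforms; the helper name log_coords is illustrative and not part of timm:

    import math
    import torch

    def log_coords(table, win_size, mode='cr'):
        # table: (2*Wh-1, 2*Ww-1, 2) tensor of raw (dy, dx) relative offsets
        table = table.clone()
        if mode == 'rw':
            # normalize by window size to [-1, 1], scale to [-8, 8], then
            # log2-compress and renormalize so outputs land back in [-1, 1]
            table[:, :, 0] /= (win_size[0] - 1)
            table[:, :, 1] /= (win_size[1] - 1)
            table *= 8
            table = torch.sign(table) * torch.log2(1.0 + table.abs())
            table /= math.log2(9)  # log2(1 + 8) bounds the magnitude
        else:
            # 'cr': unnormalized, natural-log compressed coords
            table = torch.sign(table) * torch.log(1.0 + table.abs())
        return table

    # raw relative offset table for a 7x7 window
    wh, ww = 7, 7
    dy = torch.arange(-(wh - 1), wh, dtype=torch.float32)
    dx = torch.arange(-(ww - 1), ww, dtype=torch.float32)
    table = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), dim=-1)
    rw = log_coords(table, (wh, ww), mode='rw')
    print(rw.min().item(), rw.max().item())  # -1.0, 1.0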
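
The --use-train-size flag added in patch 24 simply flips use_test_size when resolving the data config; a hedged sketch of the equivalent programmatic calls (the printed sizes depend on each model's pretrained cfg):

    import timm
    from timm.data import resolve_data_config

    # edgenext_small carries a separate test_input_size in its pretrained cfg (patch 25)
    model = timm.create_model('edgenext_small', pretrained=False)

    # default validate.py behaviour: prefer the cfg's test-time input size
    test_cfg = resolve_data_config({}, model=model, use_test_size=True)

    # with --use-train-size set: fall back to the train input size
    train_cfg = resolve_data_config({}, model=model, use_test_size=False)

    print(test_cfg['input_size'], train_cfg['input_size'])  # e.g. (3, 320, 320) vs the smaller train size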