From 58ba49c8ef4cd56b15e32b1eb17b268ed7289200 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 Jan 2022 15:38:32 -0800 Subject: [PATCH 1/4] Add MobileViT models (w/ ByobNet base). Close #1038. --- timm/models/__init__.py | 1 + timm/models/mobilevit.py | 248 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 timm/models/mobilevit.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 44e31f36..35209a2b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -24,6 +24,7 @@ from .inception_v4 import * from .levit import * from .mlp_mixer import * from .mobilenetv3 import * +from .mobilevit import * from .nasnet import * from .nest import * from .nfnet import * diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py new file mode 100644 index 00000000..1cf519c2 --- /dev/null +++ b/timm/models/mobilevit.py @@ -0,0 +1,248 @@ +""" MobileViT + +Paper: +`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 + +MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below) +License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source) + +Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022, Ross Wightman +""" +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2020 Apple Inc. All Rights Reserved. +# +import math +from typing import Union, Callable, Dict, Tuple, Optional + +import torch +from torch import nn +import torch.nn.functional as F + +from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups +from .layers import to_2tuple, make_divisible +from .vision_transformer import Block as TransformerBlock +from .helpers import build_model_with_cfg +from .registry import register_model + +__all__ = [] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': (0, 0, 0), 'std': (1, 1, 1), + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + 'fixed_input_size': False, 'min_input_size': (3, 256, 256), + **kwargs + } + + +default_cfgs = { + # GPU-Efficient (ResNet) weights + 'mobilevit_xxs': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xxs-ad385b40.pth'), + 'mobilevit_xs': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xs-8fbd6366.pth'), + 'mobilevit_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'), +} + + +def _inverted_residual_block(d, c, s, br=4.0): + # inverted residual is a bottleneck block with bottle_ratio > 1 applied to in_chs, linear output, gs=1 (depthwise) + return ByoBlockCfg( + type='bottle', d=d, c=c, s=s, gs=1, br=br, + block_kwargs=dict(bottle_in=True, linear_out=True)) + + +def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4, br=4.0): + # inverted residual + mobilevit blocks as per MobileViT network + return ( + _inverted_residual_block(d=d, c=c, s=s, br=br), + ByoBlockCfg( + type='mobilevit', d=1, c=c, s=1, + block_kwargs=dict( + transformer_dim=transformer_dim, + transformer_depth=transformer_depth, + patch_size=patch_size) + ) + ) + + +model_cfgs = dict( + mobilevit_xxs=ByoModelCfg( + blocks=( + _inverted_residual_block(d=1, c=16, s=1, 
br=2.0), + _inverted_residual_block(d=3, c=24, s=2, br=2.0), + _mobilevit_block(d=1, c=48, s=2, transformer_dim=64, transformer_depth=2, patch_size=2, br=2.0), + _mobilevit_block(d=1, c=64, s=2, transformer_dim=80, transformer_depth=4, patch_size=2, br=2.0), + _mobilevit_block(d=1, c=80, s=2, transformer_dim=96, transformer_depth=3, patch_size=2, br=2.0), + ), + stem_chs=16, + stem_type='3x3', + stem_pool='', + downsample='', + act_layer='silu', + num_features=320, + ), + + mobilevit_xs=ByoModelCfg( + blocks=( + _inverted_residual_block(d=1, c=32, s=1), + _inverted_residual_block(d=3, c=48, s=2), + _mobilevit_block(d=1, c=64, s=2, transformer_dim=96, transformer_depth=2, patch_size=2), + _mobilevit_block(d=1, c=80, s=2, transformer_dim=120, transformer_depth=4, patch_size=2), + _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=3, patch_size=2), + ), + stem_chs=16, + stem_type='3x3', + stem_pool='', + downsample='', + act_layer='silu', + num_features=384, + ), + + mobilevit_s=ByoModelCfg( + blocks=( + _inverted_residual_block(d=1, c=32, s=1), + _inverted_residual_block(d=3, c=64, s=2), + _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=2, patch_size=2), + _mobilevit_block(d=1, c=128, s=2, transformer_dim=192, transformer_depth=4, patch_size=2), + _mobilevit_block(d=1, c=160, s=2, transformer_dim=240, transformer_depth=3, patch_size=2), + ), + stem_chs=16, + stem_type='3x3', + stem_pool='', + downsample='', + act_layer='silu', + num_features=640, + ), +) + + +class MobileViTBlock(nn.Module): + """ MobileViT block + Paper: https://arxiv.org/abs/2110.02178?context=cs.LG + """ + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 3, + stride: int = 1, + bottle_ratio: float = 1.0, + group_size: Optional[int] = None, + dilation: Tuple[int, int] = (1, 1), + mlp_ratio: float = 2.0, + transformer_dim: Optional[int] = None, + transformer_depth: int = 2, + patch_size: int = 8, + num_heads: int = 4, + attn_drop: float = 0., + drop: int = 0., + no_fusion: bool = False, + drop_path_rate: float = 0., + layers: LayerFn = None, + transformer_norm_layer: Callable = nn.LayerNorm, + downsample: str = '' + ): + super(MobileViTBlock, self).__init__() + + layers = layers or LayerFn() + groups = num_groups(group_size, in_chs) + out_chs = out_chs or in_chs + transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs) + + self.conv_kxk = layers.conv_norm_act( + in_chs, in_chs, kernel_size=kernel_size, + stride=stride, groups=groups, dilation=dilation[0]) + self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False) + + self.transformer = nn.Sequential(*[ + TransformerBlock( + transformer_dim, mlp_ratio=mlp_ratio, num_heads=num_heads, qkv_bias=True, + attn_drop=attn_drop, drop=drop, drop_path=drop_path_rate, + act_layer=layers.act, norm_layer=transformer_norm_layer) + for _ in range(transformer_depth) + ]) + self.norm = transformer_norm_layer(transformer_dim) + + self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1) + + if no_fusion: + self.conv_fusion = None + else: + self.conv_fusion = layers.conv_norm_act(in_chs + out_chs, out_chs, kernel_size=kernel_size, stride=1) + + self.patch_size = to_2tuple(patch_size) + self.patch_area = self.patch_size[0] * self.patch_size[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + + # Local representation + x = self.conv_kxk(x) + x = self.conv_1x1(x) + + # Unfold (feature map -> patches) + patch_h, patch_w 
= self.patch_size + B, C, H, W = x.shape + new_h, new_w = int(math.ceil(H / patch_h) * patch_h), int(math.ceil(W / patch_w) * patch_w) + num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w + num_patches = num_patch_h * num_patch_w # N + interpolate = False + if new_h != H or new_w != W: + # Note: Padding can be done, but then it needs to be handled in attention function. + x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False) + interpolate = True + + # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w] + x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, patch_w).transpose(1, 2) + # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w + x = x.reshape(B, C, num_patches, self.patch_area).transpose(1, 3).reshape(B * self.patch_area, num_patches, -1) + + # Global representations + x = self.transformer(x) + x = self.norm(x) + + # Fold (patch -> feature map) + # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w] + x = x.contiguous().view(B, self.patch_area, num_patches, -1) + x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, patch_h, patch_w) + # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] + x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w) + if interpolate: + x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False) + + x = self.conv_proj(x) + if self.conv_fusion is not None: + x = self.conv_fusion(torch.cat((shortcut, x), dim=1)) + return x + + +register_block('mobilevit', MobileViTBlock) + + +def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + ByobNet, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def mobilevit_xxs(pretrained=False, **kwargs): + return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevit_xs(pretrained=False, **kwargs): + return _create_mobilevit('mobilevit_xs', pretrained=pretrained, **kwargs) + + +@register_model +def mobilevit_s(pretrained=False, **kwargs): + return _create_mobilevit('mobilevit_s', pretrained=pretrained, **kwargs) From bcaeb91b03e604ab7aa2175f9573bb188e08b392 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 Jan 2022 15:41:08 -0800 Subject: [PATCH 2/4] Version to 0.6.0, possible interface incompatibilities vs 0.5.x --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index 31d29d82..ef7eb44d 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.5.5' +__version__ = '0.6.0' From 2c3870e107968f921f012b63e88432c21da27aad Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 Jan 2022 22:36:09 -0800 Subject: [PATCH 3/4] semobilevit_s for good measure --- timm/models/mobilevit.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 1cf519c2..8000ed2e 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -34,19 +34,19 @@ def _cfg(url='', **kwargs): 'crop_pct': 0.9, 'interpolation': 'bicubic', 'mean': (0, 0, 0), 'std': (1, 1, 1), 'first_conv': 'stem.conv', 'classifier': 'head.fc', - 'fixed_input_size': False, 'min_input_size': (3, 256, 256), + 'fixed_input_size': False, **kwargs } default_cfgs = { - # GPU-Efficient (ResNet) weights 'mobilevit_xxs': 
_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xxs-ad385b40.pth'), 'mobilevit_xs': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xs-8fbd6366.pth'), 'mobilevit_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'), + 'semobilevit_s': _cfg(), } @@ -119,6 +119,23 @@ model_cfgs = dict( act_layer='silu', num_features=640, ), + + semobilevit_s=ByoModelCfg( + blocks=( + _inverted_residual_block(d=1, c=32, s=1), + _inverted_residual_block(d=3, c=64, s=2), + _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=2, patch_size=2), + _mobilevit_block(d=1, c=128, s=2, transformer_dim=192, transformer_depth=4, patch_size=2), + _mobilevit_block(d=1, c=160, s=2, transformer_dim=240, transformer_depth=3, patch_size=2), + ), + stem_chs=16, + stem_type='3x3', + stem_pool='', + downsample='', + attn_layer='se', + attn_kwargs=dict(rd_ratio=1/8), + num_features=640, + ), ) @@ -246,3 +263,8 @@ def mobilevit_xs(pretrained=False, **kwargs): @register_model def mobilevit_s(pretrained=False, **kwargs): return _create_mobilevit('mobilevit_s', pretrained=pretrained, **kwargs) + + +@register_model +def semobilevit_s(pretrained=False, **kwargs): + return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs) \ No newline at end of file From 372ad5fa0dbeb74dcec81db06e9ff69b3d5a2eb6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 28 Feb 2022 13:56:23 -0800 Subject: [PATCH 4/4] Significant model refactor and additions: * All models updated with revised foward_features / forward_head interface * Vision transformer and MLP based models consistently output sequence from forward_features (pooling or token selection considered part of 'head') * WIP param grouping interface to allow consistent grouping of parameters for layer-wise decay across all model types * Add gradient checkpointing support to a significant % of models, especially popular architectures * Formatting and interface consistency improvements across models * layer-wise LR decay impl part of optimizer factory w/ scale support in scheduler * Poolformer and Volo architectures added --- benchmark.py | 5 + tests/test_models.py | 40 +- timm/data/distributed_sampler.py | 8 +- timm/data/loader.py | 41 +- timm/models/__init__.py | 2 + timm/models/beit.py | 168 +++--- timm/models/byobnet.py | 71 ++- timm/models/cait.py | 61 +- timm/models/coat.py | 119 ++-- timm/models/convit.py | 56 +- timm/models/convmixer.py | 57 +- timm/models/convnext.py | 121 ++-- timm/models/crossvit.py | 38 +- timm/models/cspnet.py | 64 ++- timm/models/deit.py | 39 +- timm/models/densenet.py | 22 +- timm/models/dla.py | 92 ++- timm/models/dpn.py | 39 +- timm/models/efficientnet.py | 47 +- timm/models/features.py | 4 +- timm/models/fx_features.py | 4 +- timm/models/ghostnet.py | 45 +- timm/models/gluon_xception.py | 26 +- timm/models/hardcorenas.py | 2 +- timm/models/helpers.py | 206 ++++++- timm/models/hrnet.py | 80 ++- timm/models/inception_resnet_v2.py | 35 +- timm/models/inception_v3.py | 106 ++-- timm/models/inception_v4.py | 21 +- timm/models/layers/attention_pool2d.py | 57 +- timm/models/layers/classifier.py | 10 +- timm/models/layers/evo_norm.py | 8 +- timm/models/layers/pos_embed.py | 207 +++++++ timm/models/levit.py | 278 +++++---- timm/models/mlp_mixer.py | 31 +- timm/models/mobilenetv3.py | 41 +- timm/models/mobilevit.py | 4 +- timm/models/nasnet.py | 30 +- 
timm/models/nest.py | 55 +- timm/models/nfnet.py | 137 ++--- timm/models/pit.py | 10 +- timm/models/pnasnet.py | 18 +- timm/models/poolformer.py | 322 +++++++++++ timm/models/regnet.py | 70 ++- timm/models/res2net.py | 12 +- timm/models/resnest.py | 11 +- timm/models/resnet.py | 53 +- timm/models/resnetv2.py | 40 +- timm/models/rexnet.py | 41 +- timm/models/selecsls.py | 22 +- timm/models/senet.py | 55 +- timm/models/sknet.py | 29 +- timm/models/swin_transformer.py | 292 ++++++---- timm/models/tnt.py | 50 +- timm/models/tresnet.py | 27 +- timm/models/twins.py | 51 +- timm/models/vgg.py | 51 +- timm/models/visformer.py | 93 +-- timm/models/vision_transformer.py | 134 +++-- timm/models/volo.py | 750 +++++++++++++++++++++++++ timm/models/vovnet.py | 53 +- timm/models/xception.py | 26 +- timm/models/xception_aligned.py | 55 +- timm/models/xcit.py | 63 ++- timm/optim/optim_factory.py | 137 ++++- timm/scheduler/scheduler.py | 5 +- train.py | 28 +- 67 files changed, 3761 insertions(+), 1214 deletions(-) create mode 100644 timm/models/layers/pos_embed.py create mode 100644 timm/models/poolformer.py create mode 100644 timm/models/volo.py diff --git a/benchmark.py b/benchmark.py index 06f23a72..422da45d 100755 --- a/benchmark.py +++ b/benchmark.py @@ -89,6 +89,8 @@ parser.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') +parser.add_argument('--grad-checkpointing', action='store_true', default=False, + help='Enable gradient checkpointing through model blocks/stages') parser.add_argument('--amp', action='store_true', default=False, help='use PyTorch Native AMP for mixed precision training. 
Overrides --precision arg.') parser.add_argument('--precision', default='float32', type=str, @@ -322,6 +324,9 @@ class TrainBenchmarkRunner(BenchmarkRunner): opt=kwargs.pop('opt', 'sgd'), lr=kwargs.pop('lr', 1e-4)) + if kwargs.pop('grad_checkpointing', False): + self.model.set_grad_checkpointing() + def _gen_target(self, batch_size): return torch.empty( (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes) diff --git a/tests/test_models.py b/tests/test_models.py index 6b448dc9..11e84027 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,7 +24,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', - 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*'] + 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', + 'poolformer_*', 'volo_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures @@ -144,7 +145,7 @@ def test_model_default_cfgs(model_name, batch_size): # test forward_features (always unpooled) outputs = model.forward_features(input_tensor) - assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] + assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2], 'unpooled feature shape != config' # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features model.reset_classifier(0) @@ -156,8 +157,8 @@ def test_model_default_cfgs(model_name, batch_size): model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through outputs = model.forward(input_tensor) assert len(outputs.shape) == 4 - if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet): - # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ + if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)): + # mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] if 'pruned' not in model_name: # FIXME better pruned model handling @@ -165,8 +166,7 @@ def test_model_default_cfgs(model_name, batch_size): model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval() outputs = model.forward(input_tensor) assert len(outputs.shape) == 4 - if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet): - # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ + if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)): assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] # check classifier name matches default_cfg @@ -204,9 +204,11 @@ def test_model_default_cfgs_non_std(model_name, batch_size): outputs = model.forward_features(input_tensor) if isinstance(outputs, (tuple, list)): - outputs = outputs[0] - feat_dim = -1 if outputs.ndim == 3 else 1 - assert outputs.shape[feat_dim] == model.num_features + # cannot currently verify multi-tensor output. 
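As an aside (not part of the patch), a minimal usage sketch tying the new MobileViT models from the first commit to the revised forward_features / forward_head interface these tests exercise; it assumes a timm build that already contains this patch set.

import torch
import timm

model = timm.create_model('mobilevit_s', pretrained=False).eval()   # registered in PATCH 1
x = torch.randn(1, 3, 256, 256)                       # matches the default_cfg input_size
feats = model.forward_features(x)                     # unpooled features, (1, 640, 8, 8) for this config
logits = model.forward_head(feats)                    # global pool + classifier, (1, 1000)
pooled = model.forward_head(feats, pre_logits=True)   # pooled features, classifier skipped
assert torch.allclose(logits, model(x))               # forward() == forward_features() -> forward_head()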
+ pass + else: + feat_dim = -1 if outputs.ndim == 3 else 1 + assert outputs.shape[feat_dim] == model.num_features # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features model.reset_classifier(0) @@ -214,7 +216,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size): if isinstance(outputs, (tuple, list)): outputs = outputs[0] feat_dim = -1 if outputs.ndim == 3 else 1 - assert outputs.shape[feat_dim] == model.num_features + assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config' model = create_model(model_name, pretrained=False, num_classes=0).eval() outputs = model.forward(input_tensor) @@ -319,13 +321,18 @@ def _create_fx_model(model, train=False): # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names - train_nodes, eval_nodes = get_graph_node_names( - model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + tracer_kwargs = dict( + leaf_modules=list(_leaf_modules), + autowrap_functions=list(_autowrap_functions), + #enable_cpatching=True, + param_shapes_constant=True + ) + train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs) eval_return_nodes = [eval_nodes[-1]] train_return_nodes = [train_nodes[-1]] if train: - tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) + tracer = NodePathTracer(**tracer_kwargs) graph = tracer.trace(model) graph_nodes = list(reversed(graph.nodes)) output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] @@ -334,8 +341,11 @@ def _create_fx_model(model, train=False): train_return_nodes = [train_nodes[ix] for ix in output_node_indices] fx_model = create_feature_extractor( - model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes, - tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + model, + train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes, + tracer_kwargs=tracer_kwargs, + ) return fx_model diff --git a/timm/data/distributed_sampler.py b/timm/data/distributed_sampler.py index 1cefc31d..54ff0459 100644 --- a/timm/data/distributed_sampler.py +++ b/timm/data/distributed_sampler.py @@ -108,7 +108,13 @@ class RepeatAugSampler(Sampler): indices = torch.arange(start=0, end=len(self.dataset)) # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....] 
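A quick worked example (not part of the patch) of the non-integer repeat handling introduced below: with num_repeats=1.5 the index sequence is built by taking floor(i / num_repeats) rather than repeating every index an integer number of times.

import math
import torch

num_repeats = 1.5
dataset_len = 4
indices = torch.arange(dataset_len)
repeat_size = math.ceil(num_repeats * dataset_len)    # 6
resampled = indices[torch.tensor([int(i // num_repeats) for i in range(repeat_size)])]
print(resampled.tolist())                             # [0, 0, 1, 2, 2, 3]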
- indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() + if isinstance(self.num_repeats, float) and not self.num_repeats.is_integer(): + # resample for repeats w/ non-integer ratio + repeat_size = math.ceil(self.num_repeats * len(self.dataset)) + indices = indices[torch.tensor([int(i // self.num_repeats) for i in range(repeat_size)])] + else: + indices = torch.repeat_interleave(indices, repeats=int(self.num_repeats), dim=0) + indices = indices.tolist() # leaving as tensor thrashes dataloader memory # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) if padding_size > 0: diff --git a/timm/data/loader.py b/timm/data/loader.py index 67d8cd83..ecc075c0 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -7,6 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman """ import random from functools import partial +from itertools import repeat from typing import Callable import torch.utils.data @@ -54,20 +55,37 @@ def fast_collate(batch): assert False +def expand_to_chs(x, n): + if not isinstance(x, (tuple, list)): + x = tuple(repeat(x, n)) + elif len(x) == 1: + x = x * n + else: + assert len(x) == n, 'normalization stats must match image channels' + return x + + class PrefetchLoader: - def __init__(self, - loader, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - fp16=False, - re_prob=0., - re_mode='const', - re_count=1, - re_num_splits=0): + def __init__( + self, + loader, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + channels=3, + fp16=False, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0): + + mean = expand_to_chs(mean, channels) + std = expand_to_chs(std, channels) + normalization_shape = (1, channels, 1, 1) + self.loader = loader - self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) - self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) + self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape) + self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape) self.fp16 = fp16 if fp16: self.mean = self.mean.half() @@ -247,6 +265,7 @@ def create_loader( loader, mean=mean, std=std, + channels=input_size[0], fp16=fp16, re_prob=prefetch_re_prob, re_mode=re_mode, diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 35209a2b..497aa53d 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -30,6 +30,7 @@ from .nest import * from .nfnet import * from .pit import * from .pnasnet import * +from .poolformer import * from .regnet import * from .res2net import * from .resnest import * @@ -47,6 +48,7 @@ from .vgg import * from .visformer import * from .vision_transformer import * from .vision_transformer_hybrid import * +from .volo import * from .vovnet import * from .xception import * from .xception_aligned import * diff --git a/timm/models/beit.py b/timm/models/beit.py index 68ca44b7..557ac9b0 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -20,11 +20,12 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below # --------------------------------------------------------' import math from functools import partial -from typing import Optional +from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint from .helpers import build_model_with_cfg from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ @@ -71,6 +72,28 @@ default_cfgs = { } +def 
gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + # cls to token & token 2 cls & cls to cls + # get pair-wise relative position index for each token inside the window + window_area = window_size[0] * window_size[1] + coords = torch.stack(torch.meshgrid( + [torch.arange(window_size[0]), + torch.arange(window_size[1])])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + return relative_position_index + + class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, attn_drop=0., @@ -98,26 +121,7 @@ class Attention(nn.Module): self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH - # cls to token & token 2 cls & cls to cls - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(window_size[0]) - coords_w = torch.arange(window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * window_size[1] - 1 - relative_position_index = \ - torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) - relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - relative_position_index[0, 0:] = self.num_relative_distance - 3 - relative_position_index[0:, 0] = self.num_relative_distance - 2 - relative_position_index[0, 0] = self.num_relative_distance - 1 - - self.register_buffer("relative_position_index", relative_position_index) + self.register_buffer("relative_position_index", gen_relative_position_index(window_size)) else: self.window_size = None self.relative_position_bias_table = None @@ -127,8 +131,17 @@ class Attention(nn.Module): self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) - def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None): + def _get_rel_pos_bias(self): + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + def forward(self, x, 
shared_rel_pos_bias: Optional[torch.Tensor] = None): B, N, C = x.shape + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) @@ -138,15 +151,9 @@ class Attention(nn.Module): attn = (q @ k.transpose(-2, -1)) if self.relative_position_bias_table is not None: - relative_position_bias = \ - self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] + 1, - self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if rel_pos_bias is not None: - attn = attn + rel_pos_bias + attn = attn + self._get_rel_pos_bias() + if shared_rel_pos_bias is not None: + attn = attn + shared_rel_pos_bias attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) @@ -159,9 +166,10 @@ class Attention(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, - window_size=None, attn_head_dim=None): + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( @@ -174,17 +182,17 @@ class Block(nn.Module): self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if init_values: - self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) - self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True) else: self.gamma_1, self.gamma_2 = None, None - def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None): + def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None): if self.gamma_1 is None: - x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)) x = x + self.drop_path(self.mlp(self.norm2(x))) else: - x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x @@ -194,37 +202,15 @@ class RelativePositionBias(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.window_size = window_size - self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 - self.relative_position_bias_table = nn.Parameter( - torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH - # cls to token & token 2 cls & cls to cls - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(window_size[0]) - coords_w = torch.arange(window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, 
:, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * window_size[1] - 1 - relative_position_index = \ - torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) - relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - relative_position_index[0, 0:] = self.num_relative_distance - 3 - relative_position_index[0:, 0] = self.num_relative_distance - 2 - relative_position_index[0, 0] = self.num_relative_distance - 1 - - self.register_buffer("relative_position_index", relative_position_index) - + self.window_area = window_size[0] * window_size[1] + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) # trunc_normal_(self.relative_position_bias_table, std=.02) + self.register_buffer("relative_position_index", gen_relative_position_index(window_size)) def forward(self): - relative_position_bias = \ - self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] + 1, - self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_area + 1, self.window_area + 1, -1) # Wh*Ww,Wh*Ww,nH return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww @@ -242,6 +228,7 @@ class Beit(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.grad_checkpointing = False self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) @@ -258,7 +245,6 @@ class Beit(nn.Module): self.rel_pos_bias = None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - self.use_rel_pos_bias = use_rel_pos_bias self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, @@ -298,45 +284,63 @@ class Beit(nn.Module): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def get_num_layers(self): - return len(self.blocks) - @torch.jit.ignore def no_weight_decay(self): - return {'pos_embed', 'cls_token'} + nwd = {'pos_embed', 'cls_token'} + for n, _ in self.named_parameters(): + if 'relative_position_bias_table' in n: + nwd.add(n) + return nwd + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))], + ) + return matcher + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) - batch_size, seq_len, _ = x.size() - - cls_tokens = 
self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_tokens, x), dim=1) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.pos_embed is not None: x = x + self.pos_embed x = self.pos_drop(x) rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None for blk in self.blocks: - x = blk(x, rel_pos_bias=rel_pos_bias) - + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, shared_rel_pos_bias=rel_pos_bias) x = self.norm(x) return x - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): if self.fc_norm is not None: x = x[:, 1:].mean(dim=1) x = self.fc_norm(x) else: x = x[:, 0] - x = self.head(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 554b9a6e..18bd53af 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -33,7 +33,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply +from .helpers import build_model_with_cfg, named_apply, checkpoint_seq from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\ EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d @@ -159,9 +159,9 @@ default_cfgs = { mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94), - 'regnetz_d8_evob': _cfgr( + 'regnetz_c16_evos': _cfgr( url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95), + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94), 'regnetz_d8_evos': _cfgr( url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95), @@ -621,20 +621,19 @@ model_cfgs = dict( attn_kwargs=dict(rd_ratio=0.25), block_kwargs=dict(bottle_in=True, linear_out=True), ), - regnetz_d8_evob=ByoModelCfg( + regnetz_c16_evos=ByoModelCfg( blocks=( - ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4), - ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4), - ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4), - ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4), + ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4), + ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4), + ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4), + ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4), ), - stem_chs=64, - stem_type='tiered', + stem_chs=32, stem_pool='', downsample='', - num_features=1792, + num_features=1536, act_layer='silu', - norm_layer='evonormb0', + norm_layer=partial(EvoNorm2dS0a, group_size=16), attn_layer='se', attn_kwargs=dict(rd_ratio=0.25), block_kwargs=dict(bottle_in=True, linear_out=True), @@ -888,10 +887,10 @@ def regnetz_b16_evos(pretrained=False, **kwargs): @register_model -def regnetz_d8_evob(pretrained=False, **kwargs): +def regnetz_c16_evos(pretrained=False, **kwargs): """ """ - return _create_byobnet('regnetz_d8_evob', pretrained=pretrained, **kwargs) + 
return _create_byobnet('regnetz_c16_evos', pretrained=pretrained, **kwargs) @register_model @@ -1200,9 +1199,10 @@ class SelfAttnBlock(nn.Module): """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1 """ - def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True, - feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + def __init__( + self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, + downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True, + feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): super(SelfAttnBlock, self).__init__() assert layers is not None mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio) @@ -1269,8 +1269,9 @@ def create_block(block: Union[str, nn.Module], **kwargs): class Stem(nn.Sequential): - def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool', - num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None): + def __init__( + self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool', + num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None): super().__init__() assert stride in (2, 4) layers = layers or LayerFn() @@ -1479,11 +1480,13 @@ class ByobNet(nn.Module): Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act). """ - def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, - zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.): + def __init__( + self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate + self.grad_checkpointing = False layers = get_layer_fns(cfg) if cfg.fixed_input_size: assert img_size is not None, 'img_size argument is required for fixed input size model' @@ -1514,6 +1517,22 @@ class ByobNet(nn.Module): # init weights named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=[ + (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).(\d+)', None), + (r'^final_conv', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -1522,13 +1541,19 @@ class ByobNet(nn.Module): def forward_features(self, x): x = self.stem(x) - x = self.stages(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) x = self.final_conv(x) return x + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + def forward(self, x): x = self.forward_features(x) - x = self.head(x) + x = self.forward_head(x) return x diff --git a/timm/models/cait.py b/timm/models/cait.py index 331111f2..bcc91497 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -9,13 +9,13 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. 
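A short sketch (not part of the patch) of the gradient checkpointing hook that benchmark.py and ByobNet wire up above; the model choice is arbitrary and a timm build containing this patch is assumed.

import torch
import timm

model = timm.create_model('mobilevit_xs', pretrained=False)
model.set_grad_checkpointing(True)    # ByobNet.forward_features then runs its stages via checkpoint_seq
model.train()
x = torch.randn(2, 3, 256, 256, requires_grad=True)
loss = model(x).sum()
loss.backward()                       # stage activations are recomputed in backward instead of being stored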
from copy import deepcopy +from functools import partial import torch import torch.nn as nn -from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ from .registry import register_model @@ -202,14 +202,13 @@ class Cait(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to adapt to our cait models def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., - norm_layer=partial(nn.LayerNorm, eps=1e-6), - global_pool=None, + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., block_layers=LayerScaleBlock, block_layers_token=LayerScaleBlockClassAttn, patch_layer=PatchEmbed, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, attn_block=TalkingHeadAttn, mlp_block=Mlp, @@ -220,9 +219,12 @@ class Cait(nn.Module): mlp_ratio_token_only=4.0 ): super().__init__() + assert global_pool in ('', 'token', 'avg') self.num_classes = num_classes + self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim + self.grad_checkpointing = False self.patch_embed = patch_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) @@ -271,32 +273,61 @@ class Cait(nn.Module): def no_weight_decay(self): return {'pos_embed', 'cls_token'} + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def group_matcher(self, coarse=False): + def _matcher(name): + if any([name.startswith(n) for n in ('cls_token', 'pos_embed', 'patch_embed')]): + return 0 + elif name.startswith('blocks.'): + return int(name.split('.')[1]) + 1 + elif name.startswith('blocks_token_only.'): + # overlap token only blocks with last blocks + to_offset = len(self.blocks) - len(self.blocks_token_only) + 1 + return int(name.split('.')[1]) + to_offset + elif name.startswith('norm.'): + return len(self.blocks) + else: + return float('inf') + return _matcher + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'token', 'avg') + self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): - B = x.shape[0] x = self.patch_embed(x) x = x + self.pos_embed x = self.pos_drop(x) - x = self.blocks(x) - - cls_tokens = self.cls_token.expand(B, -1, -1) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) for i, blk in enumerate(self.blocks_token_only): cls_tokens = blk(x, cls_tokens) x = torch.cat((cls_tokens, x), dim=1) - x = self.norm(x) return x + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + return x if 
pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = x[:, 0] - x = self.head(x) + x = self.forward_head(x) return x diff --git a/timm/models/coat.py b/timm/models/coat.py index 4188243f..c3071a6c 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -9,7 +9,7 @@ Modified from timm/models/vision_transformer.py """ from copy import deepcopy from functools import partial -from typing import Tuple, List +from typing import Tuple, List, Union import torch import torch.nn as nn @@ -125,7 +125,7 @@ class ConvRelPosEnc(nn.Module): return EV_hat -class FactorAtt_ConvRelPosEnc(nn.Module): +class FactorAttnConvRelPosEnc(nn.Module): """ Factorized attention with convolutional relative position encoding class. """ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None): super().__init__() @@ -205,7 +205,7 @@ class SerialBlock(nn.Module): self.cpe = shared_cpe self.norm1 = norm_layer(dim) - self.factoratt_crpe = FactorAtt_ConvRelPosEnc( + self.factoratt_crpe = FactorAttnConvRelPosEnc( dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -239,15 +239,15 @@ class ParallelBlock(nn.Module): self.norm12 = norm_layer(dims[1]) self.norm13 = norm_layer(dims[2]) self.norm14 = norm_layer(dims[3]) - self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc( + self.factoratt_crpe2 = FactorAttnConvRelPosEnc( dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[1] ) - self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc( + self.factoratt_crpe3 = FactorAttnConvRelPosEnc( dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[2] ) - self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc( + self.factoratt_crpe4 = FactorAttnConvRelPosEnc( dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[3] ) @@ -328,17 +328,19 @@ class ParallelBlock(nn.Module): class CoaT(nn.Module): """ CoaT class. """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0), + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0), serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), - return_interm_layers=False, out_features=None, crpe_window=None, **kwargs): + return_interm_layers=False, out_features=None, crpe_window=None, global_pool='token'): super().__init__() + assert global_pool in ('token', 'avg') crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers self.out_features = out_features self.embed_dims = embed_dims self.num_features = embed_dims[-1] self.num_classes = num_classes + self.global_pool = global_pool # Patch embeddings. 
img_size = to_2tuple(img_size) @@ -470,61 +472,73 @@ class CoaT(nn.Module): def no_weight_decay(self): return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'} + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem1=r'^cls_token1|patch_embed1|crpe1|cpe1', + serial_blocks1=r'^serial_blocks1\.(\d+)', + stem2=r'^cls_token2|patch_embed2|crpe2|cpe2', + serial_blocks2=r'^serial_blocks2\.(\d+)', + stem3=r'^cls_token3|patch_embed3|crpe3|cpe3', + serial_blocks3=r'^serial_blocks3\.(\d+)', + stem4=r'^cls_token4|patch_embed4|crpe4|cpe4', + serial_blocks4=r'^serial_blocks4\.(\d+)', + parallel_blocks=[ # FIXME (partially?) overlap parallel w/ serial blocks?? + (r'^parallel_blocks\.(\d+)', None), + (r'^norm|aggregate', (99999,)), + ] + ) + return matcher + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('token', 'avg') + self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - def insert_cls(self, x, cls_token): - """ Insert CLS token. """ - cls_tokens = cls_token.expand(x.shape[0], -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - return x - - def remove_cls(self, x): - """ Remove CLS token. """ - return x[:, 1:, :] - def forward_features(self, x0): B = x0.shape[0] # Serial blocks 1. x1 = self.patch_embed1(x0) H1, W1 = self.patch_embed1.grid_size - x1 = self.insert_cls(x1, self.cls_token1) + x1 = insert_cls(x1, self.cls_token1) for blk in self.serial_blocks1: x1 = blk(x1, size=(H1, W1)) - x1_nocls = self.remove_cls(x1) - x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() + x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() # Serial blocks 2. x2 = self.patch_embed2(x1_nocls) H2, W2 = self.patch_embed2.grid_size - x2 = self.insert_cls(x2, self.cls_token2) + x2 = insert_cls(x2, self.cls_token2) for blk in self.serial_blocks2: x2 = blk(x2, size=(H2, W2)) - x2_nocls = self.remove_cls(x2) - x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous() + x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous() # Serial blocks 3. x3 = self.patch_embed3(x2_nocls) H3, W3 = self.patch_embed3.grid_size - x3 = self.insert_cls(x3, self.cls_token3) + x3 = insert_cls(x3, self.cls_token3) for blk in self.serial_blocks3: x3 = blk(x3, size=(H3, W3)) - x3_nocls = self.remove_cls(x3) - x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous() + x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous() # Serial blocks 4. x4 = self.patch_embed4(x3_nocls) H4, W4 = self.patch_embed4.grid_size - x4 = self.insert_cls(x4, self.cls_token4) + x4 = insert_cls(x4, self.cls_token4) for blk in self.serial_blocks4: x4 = blk(x4, size=(H4, W4)) - x4_nocls = self.remove_cls(x4) - x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() + x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() # Only serial blocks: Early return. if self.parallel_blocks is None: @@ -554,20 +568,16 @@ class CoaT(nn.Module): # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). 
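To make the group_matcher specs above more concrete, here is a standalone sketch (not timm's internal grouping code) of how such regex groups are intended to map parameter names to layer indices for the layer-wise LR decay mentioned in the PATCH 4 message; the matcher dict below is a simplified, hypothetical one.

import re

matcher = dict(
    stem=r'^cls_token|pos_embed|patch_embed',   # stem / embed parameters -> layer 0
    blocks=r'^blocks\.(\d+)',                   # transformer blocks -> block index + 1
)

def layer_id(name, num_blocks=12):
    if re.match(matcher['stem'], name):
        return 0
    m = re.match(matcher['blocks'], name)
    if m:
        return int(m.group(1)) + 1
    return num_blocks + 1                       # norm / head fall into the last group

print(layer_id('patch_embed.proj.weight'))      # 0
print(layer_id('blocks.3.attn.qkv.weight'))     # 4
print(layer_id('head.weight'))                  # 13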
feat_out = {} if 'x1_nocls' in self.out_features: - x1_nocls = self.remove_cls(x1) - x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() + x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() feat_out['x1_nocls'] = x1_nocls if 'x2_nocls' in self.out_features: - x2_nocls = self.remove_cls(x2) - x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous() + x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous() feat_out['x2_nocls'] = x2_nocls if 'x3_nocls' in self.out_features: - x3_nocls = self.remove_cls(x3) - x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous() + x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous() feat_out['x3_nocls'] = x3_nocls if 'x4_nocls' in self.out_features: - x4_nocls = self.remove_cls(x4) - x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() + x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() feat_out['x4_nocls'] = x4_nocls return feat_out else: @@ -576,6 +586,18 @@ class CoaT(nn.Module): x4 = self.norm4(x4) return [x2, x3, x4] + def forward_head(self, x_feat: Union[torch.Tensor, List[torch.Tensor]], pre_logits: bool = False): + if isinstance(x_feat, list): + assert self.aggregate is not None + if self.global_pool == 'avg': + x = torch.cat([xl[:, 1:].mean(dim=1, keepdim=True) for xl in x_feat], dim=1) # [B, 3, C] + else: + x = torch.stack([xl[:, 0] for xl in x_feat], dim=1) # [B, 3, C] + x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C] + else: + x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0] + return x if pre_logits else self.head(x) + def forward(self, x) -> torch.Tensor: if not torch.jit.is_scripting() and self.return_interm_layers: # Return intermediate features (for down-stream tasks). @@ -583,15 +605,22 @@ class CoaT(nn.Module): else: # Return features for classification. x_feat = self.forward_features(x) - if isinstance(x_feat, (tuple, list)): - x = torch.cat([xl[:, :1] for xl in x_feat], dim=1) # [B, 3, C] - x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C] - else: - x = x_feat[:, 0] - x = self.head(x) + x = self.forward_head(x_feat) return x +def insert_cls(x, cls_token): + """ Insert CLS token. """ + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + return x + + +def remove_cls(x): + """ Remove CLS token. 
""" + return x[:, 1:, :] + + def checkpoint_filter_fn(state_dict, model): out_dict = {} for k, v in state_dict.items(): diff --git a/timm/models/convit.py b/timm/models/convit.py index a3287574..26849f6e 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -61,8 +61,8 @@ default_cfgs = { @register_notrace_module # reason: FX can't symbolically trace control flow in forward method class GPSA(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., - locality_strength=1.): + def __init__( + self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): super().__init__() self.num_heads = num_heads self.dim = dim @@ -169,7 +169,7 @@ class MHSA(nn.Module): indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) indd = indx ** 2 + indy ** 2 distances = indd ** .5 - distances = distances.to('cuda') + distances = distances.to(x.device) dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N if return_map: @@ -180,7 +180,7 @@ class MHSA(nn.Module): def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] + q, k, v = qkv.unbind(0) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) @@ -194,8 +194,9 @@ class MHSA(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): super().__init__() self.norm1 = norm_layer(dim) self.use_gpsa = use_gpsa @@ -219,13 +220,16 @@ class ConViT(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, - local_up_to_layer=3, locality_strength=1., use_pos_embed=True): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, + local_up_to_layer=3, locality_strength=1., use_pos_embed=True): super().__init__() + assert global_pool in ('', 'avg', 'token') embed_dim *= num_heads self.num_classes = num_classes + self.global_pool = global_pool self.local_up_to_layer = local_up_to_layer self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.locality_strength = locality_strength @@ -285,35 +289,49 @@ class ConViT(nn.Module): def no_weight_decay(self): return {'pos_embed', 'cls_token'} + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, 
num_classes, global_pool=None): self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'token', 'avg') + self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): - B = x.shape[0] x = self.patch_embed(x) - - cls_tokens = self.cls_token.expand(B, -1, -1) - if self.use_pos_embed: x = x + self.pos_embed x = self.pos_drop(x) - + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) for u, blk in enumerate(self.blocks): if u == self.local_up_to_layer: x = torch.cat((cls_tokens, x), dim=1) x = blk(x) - x = self.norm(x) return x + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = x[:, 0] - x = self.head(x) + x = self.forward_head(x) return x diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index f4eb9795..e7e2481a 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -1,7 +1,13 @@ +""" ConvMixer + +""" +import torch import torch.nn as nn + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.registry import register_model -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq +from .layers import SelectAdaptivePool2d def _cfg(url='', **kwargs): @@ -32,49 +38,68 @@ class Residual(nn.Module): class ConvMixer(nn.Module): - def __init__(self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, activation=nn.GELU, **kwargs): + def __init__( + self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, global_pool='avg', + act_layer=nn.GELU, **kwargs): super().__init__() self.num_classes = num_classes self.num_features = dim - self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() + self.grad_checkpointing = False + self.stem = nn.Sequential( nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size), - activation(), + act_layer(), nn.BatchNorm2d(dim) ) self.blocks = nn.Sequential( *[nn.Sequential( Residual(nn.Sequential( nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), - activation(), + act_layer(), nn.BatchNorm2d(dim) )), nn.Conv2d(dim, dim, kernel_size=1), - activation(), + act_layer(), nn.BatchNorm2d(dim) ) for i in range(depth)] ) - self.pooling = nn.Sequential( - nn.AdaptiveAvgPool2d((1, 1)), - nn.Flatten() - ) + self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)') + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes + if global_pool is not None: + self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.stem(x) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = 
self.blocks(x) return x - + + def forward_head(self, x, pre_logits: bool = False): + x = self.pooling(x) + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = self.pooling(x) - x = self.head(x) + x = self.forward_head(x) return x @@ -90,7 +115,7 @@ def convmixer_1536_20(pretrained=False, **kwargs): @register_model def convmixer_768_32(pretrained=False, **kwargs): - model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, activation=nn.ReLU, **kwargs) + model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs) return _create_convmixer('convmixer_768_32', pretrained, **model_args) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 8f0b9e0a..0a2df3de 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .fx_features import register_notrace_module -from .helpers import named_apply, build_model_with_cfg +from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp from .registry import register_model @@ -43,6 +43,7 @@ default_cfgs = dict( convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"), convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), + convnext_nano_hnf=_cfg(url=''), convnext_tiny_hnf=_cfg(url=''), convnext_base_in22ft1k=_cfg( @@ -151,6 +152,7 @@ class ConvNeXtStage(nn.Module): self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, norm_layer=None, cl_norm_layer=None, cross_stage=False): super().__init__() + self.grad_checkpointing = False if in_chs != out_chs or stride > 1: self.downsample = nn.Sequential( @@ -169,7 +171,10 @@ class ConvNeXtStage(nn.Module): def forward(self, x): x = self.downsample(x) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) return x @@ -190,7 +195,7 @@ class ConvNeXt(nn.Module): def __init__( self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, - depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch', head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., ): super().__init__() @@ -208,19 +213,29 @@ class ConvNeXt(nn.Module): self.feature_info = [] # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 - self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), - norm_layer(dims[0]) - ) + if stem_type == 'patch': + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), + norm_layer(dims[0]) + ) + curr_stride = patch_size + prev_chs = dims[0] + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1), + norm_layer(32), + nn.GELU(), + nn.Conv2d(32, 64, kernel_size=3, padding=1), + ) + curr_stride = 2 + prev_chs = 64 self.stages = nn.Sequential() dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] - curr_stride = patch_size - prev_chs = dims[0] stages = [] # 4 feature resolution 
stages, each consisting of multiple residual blocks for i in range(4): - stride = 2 if i > 0 else 1 + stride = 2 if curr_stride == 2 or i > 0 else 1 # FIXME support dilation / output_stride curr_stride *= stride out_chs = dims[i] @@ -235,40 +250,43 @@ class ConvNeXt(nn.Module): self.stages = nn.Sequential(*stages) self.num_features = prev_chs - if head_norm_first: - # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights) - self.norm_pre = norm_layer(self.num_features) # final norm layer, before pooling - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) - else: - # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) - self.norm_pre = nn.Identity() - self.head = nn.Sequential(OrderedDict([ + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets + # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) + self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() + self.head = nn.Sequential(OrderedDict([ ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), - ('norm', norm_layer(self.num_features)), + ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), ('drop', nn.Dropout(self.drop_rate)), - ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) - ])) + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes=0, global_pool='avg'): - if isinstance(self.head, ClassifierHead): - # norm -> global pool -> fc - self.head = ClassifierHead( - self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - else: - # pool -> norm -> fc - self.head = nn.Sequential(OrderedDict([ - ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), - ('norm', self.head.norm), - ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), - ('drop', nn.Dropout(self.drop_rate)), - ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) - ])) + def reset_classifier(self, num_classes=0, global_pool=None): + if global_pool is not None: + self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.stem(x) @@ -276,9 +294,17 @@ class ConvNeXt(nn.Module): x = self.norm_pre(x) return x + def forward_head(self, x, pre_logits: bool = False): + # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( + x = self.head.global_pool(x) + x = self.head.norm(x) + x = self.head.flatten(x) + x = self.head.drop(x) + return x if pre_logits else self.head.fc(x) + def forward(self, x): x = self.forward_features(x) - x = 
self.head(x) + x = self.forward_head(x) return x @@ -326,19 +352,34 @@ def _create_convnext(variant, pretrained=False, **kwargs): @register_model -def convnext_tiny(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) - model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args) +def convnext_nano_hnf(pretrained=False, **kwargs): + model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) + model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args) return model @register_model def convnext_tiny_hnf(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs) + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) + model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_tiny_hnfd(pretrained=False, **kwargs): + model_args = dict( + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, stem_type='dual', **kwargs) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) return model +@register_model +def convnext_tiny(pretrained=False, **kwargs): + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) + model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args) + return model + + @register_model def convnext_small(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 653da40b..5a3260bf 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -175,7 +175,6 @@ class CrossAttentionBlock(nn.Module): def forward(self, x): x = x[:, 0:1, ...] 
+ self.drop_path(self.attn(self.norm1(x))) - return x @@ -289,12 +288,14 @@ class CrossViT(nn.Module): def __init__( self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000, embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.), - qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False, + multi_conv=False, crop_scale=False, qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), global_pool='token', ): super().__init__() + assert global_pool in ('token', 'avg') self.num_classes = num_classes + self.global_pool = global_pool self.img_size = to_2tuple(img_size) img_scale = to_2tuple(img_scale) self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale] @@ -302,7 +303,7 @@ class CrossViT(nn.Module): num_patches = _compute_num_patches(self.img_size_scaled, patch_size) self.num_branches = len(patch_size) self.embed_dim = embed_dim - self.num_features = embed_dim[0] # to pass the tests + self.num_features = sum(embed_dim) self.patch_embed = nn.ModuleList() # hard-coded for torch jit script @@ -359,11 +360,26 @@ class CrossViT(nn.Module): out.add(f'pos_embed_{i}') return out + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('token', 'avg') + self.global_pool = global_pool self.head = nn.ModuleList( [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) @@ -391,12 +407,16 @@ class CrossViT(nn.Module): xs = [norm(xs[i]) for i, norm in enumerate(self.norm)] return xs + def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor: + xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs] + if pre_logits or isinstance(self.head[0], nn.Identity): + return torch.cat([x for x in xs], dim=1) + return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0) + def forward(self, x): xs = self.forward_features(x) - ce_logits = [head(xs[i][:, 0]) for i, head in enumerate(self.head)] - if not isinstance(self.head[0], nn.Identity): - ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0) - return ce_logits + x = self.forward_head(xs) + return x def _create_crossvit(variant, pretrained=False, **kwargs): diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 26c92389..75c525bf 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,11 +12,13 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ +from functools import partial + import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP 
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer from .registry import register_model @@ -172,7 +174,7 @@ class ResBottleneck(nn.Module): self.drop_path = drop_path self.act3 = act_layer(inplace=True) - def zero_init_last_bn(self): + def zero_init_last(self): nn.init.zeros_(self.conv3.bn.weight) def forward(self, x): @@ -210,7 +212,7 @@ class DarkBlock(nn.Module): self.attn = create_attn(attn_layer, channels=out_chs) self.drop_path = drop_path - def zero_init_last_bn(self): + def zero_init_last(self): nn.init.zeros_(self.conv2.bn.weight) def forward(self, x): @@ -345,9 +347,10 @@ class CspNet(nn.Module): darknet impl. I did it this way for simplicity and less special cases. """ - def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., - act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0., - zero_init_last_bn=True, stage_fn=CrossStage, block_fn=ResBottleneck): + def __init__( + self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., + act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0., + zero_init_last=True, stage_fn=CrossStage, block_fn=ResBottleneck): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -378,20 +381,25 @@ class CspNet(nn.Module): self.head = ClassifierHead( in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - nn.init.ones_(m.weight) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, mean=0.0, std=0.01) - nn.init.zeros_(m.bias) - if zero_init_last_bn: - for m in self.modules(): - if hasattr(m, 'zero_init_last_bn'): - m.zero_init_last_bn() + named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=r'^stages.(\d+)' if coarse else [ + (r'^stages.(\d+).blocks.(\d+)', None), + (r'^stages.(\d+).*transition', MATCH_PREV_GROUP), # map to last block in stage + (r'^stages.(\d+)', (0,)), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -403,12 +411,28 @@ class CspNet(nn.Module): x = self.stages(x) return x + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + def forward(self, x): x = self.forward_features(x) - x = self.head(x) + x = self.forward_head(x) return x +def _init_weights(module, name, zero_init_last=False): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(module, nn.BatchNorm2d): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=0.01) + nn.init.zeros_(module.bias) + elif zero_init_last and hasattr(module, 'zero_init_last'): + module.zero_init_last() + + def _create_cspnet(variant, pretrained=False, **kwargs): cfg_variant = variant.split('_')[0] # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] diff --git a/timm/models/deit.py b/timm/models/deit.py index 
5cb49394..3fd8655b 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -13,7 +13,7 @@ from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .registry import register_model @@ -66,10 +66,13 @@ class VisionTransformerDistilled(VisionTransformer): def __init__(self, *args, **kwargs): weight_init = kwargs.pop('weight_init', '') super().__init__(*args, **kwargs, weight_init='skip') + assert self.global_pool in ('token',) + self.num_tokens = 2 self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim)) self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() + self.distilled_training = False self.init_weights(weight_init) @@ -77,32 +80,50 @@ class VisionTransformerDistilled(VisionTransformer): trunc_normal_(self.dist_token, std=.02) super().init_weights(mode=mode) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed|dist_token', + blocks=[ + (r'^blocks.(\d+)', None), + (r'^norm', (99999,))] # final norm w/ last block + ) + + @torch.jit.ignore def get_classifier(self): return self.head, self.head_dist - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + @torch.jit.ignore + def set_distilled_training(self, enable=True): + self.distilled_training = enable + def forward_features(self, x) -> torch.Tensor: x = self.patch_embed(x) x = torch.cat(( self.cls_token.expand(x.shape[0], -1, -1), self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) x = self.norm(x) return x - def forward(self, x): - x = self.forward_features(x) - x_dist = self.head_dist(x[:, 1]) - x = self.head(x[:, 0]) - if self.training and not torch.jit.is_scripting(): + def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor: + if pre_logits: + return (x[:, 0] + x[:, 1]) / 2 + x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1]) + if self.distilled_training and self.training and not torch.jit.is_scripting(): + # only return separate classification predictions when training in distilled mode return x, x_dist else: - # during inference, return the average of both classifier predictions + # during standard train / finetune, inference average the classifier predictions return (x + x_dist) / 2 diff --git a/timm/models/densenet.py b/timm/models/densenet.py index ee66666e..304eda79 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -13,7 +13,7 @@ import torch.utils.checkpoint as cp from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, MATCH_PREV_GROUP from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, 
create_classifier from .registry import register_model @@ -162,10 +162,10 @@ class DenseNet(nn.Module): but slower. Default: *False*. See `"paper" `_ """ - def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='', - num_classes=1000, in_chans=3, global_pool='avg', - norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False, - aa_stem_only=True): + def __init__( + self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg', + bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False, + aa_stem_only=True): self.num_classes = num_classes self.drop_rate = drop_rate super(DenseNet, self).__init__() @@ -249,6 +249,18 @@ class DenseNet(nn.Module): elif isinstance(m, nn.Linear): nn.init.constant_(m.bias, 0) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^features.conv[012]|features.norm[012]|features.pool[012]', + blocks=r'^features.(?:denseblock|transition)(\d+)' if coarse else [ + (r'^features.denseblock(\d+).denselayer(\d+)', None), + (r'^features.transition(\d+)', MATCH_PREV_GROUP) # FIXME combine with previous denselayer + ] + ) + return matcher + + @torch.jit.ignore def get_classifier(self): return self.classifier diff --git a/timm/models/dla.py b/timm/models/dla.py index 2d8597a5..bc1a7394 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -6,6 +6,7 @@ Res2Net additions from: https://github.com/gasvn/Res2Net/ Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169 """ import math +from typing import List, Optional import torch import torch.nn as nn @@ -62,7 +63,7 @@ class DlaBasic(nn.Module): self.bn2 = nn.BatchNorm2d(planes) self.stride = stride - def forward(self, x, shortcut=None): + def forward(self, x, shortcut=None, children: Optional[List[torch.Tensor]] = None): if shortcut is None: shortcut = x @@ -99,7 +100,7 @@ class DlaBottleneck(nn.Module): self.bn3 = nn.BatchNorm2d(outplanes) self.relu = nn.ReLU(inplace=True) - def forward(self, x, shortcut=None): + def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None): if shortcut is None: shortcut = x @@ -147,14 +148,13 @@ class DlaBottle2neck(nn.Module): bns.append(nn.BatchNorm2d(mid_planes)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) - if self.is_first: - self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) + self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) if self.is_first else None self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(outplanes) self.relu = nn.ReLU(inplace=True) - def forward(self, x, shortcut=None): + def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None): if shortcut is None: shortcut = x @@ -164,14 +164,21 @@ class DlaBottle2neck(nn.Module): spx = torch.split(out, self.width, 1) spo = [] + sp = spx[0] # redundant, for torchscript for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): - sp = spx[i] if i == 0 or self.is_first else sp + spx[i] + if i == 0 or self.is_first: + sp = spx[i] + else: + sp = sp + spx[i] sp = conv(sp) sp = bn(sp) sp = self.relu(sp) spo.append(sp) if self.scale > 1: - spo.append(self.pool(spx[-1]) if self.is_first else spx[-1]) + if self.pool is not None: # self.is_first == True, None check for torchscript + spo.append(self.pool(spx[-1])) + else: + 
spo.append(spx[-1]) out = torch.cat(spo, 1) out = self.conv3(out) @@ -192,21 +199,20 @@ class DlaRoot(nn.Module): self.relu = nn.ReLU(inplace=True) self.shortcut = shortcut - def forward(self, *x): - children = x - x = self.conv(torch.cat(x, 1)) + def forward(self, x_children: List[torch.Tensor]): + x = self.conv(torch.cat(x_children, 1)) x = self.bn(x) if self.shortcut: - x += children[0] + x += x_children[0] x = self.relu(x) return x class DlaTree(nn.Module): - def __init__(self, levels, block, in_channels, out_channels, stride=1, - dilation=1, cardinality=1, base_width=64, - level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False): + def __init__( + self, levels, block, in_channels, out_channels, stride=1, dilation=1, cardinality=1, + base_width=64, level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False): super(DlaTree, self).__init__() if root_dim == 0: root_dim = 2 * out_channels @@ -225,38 +231,39 @@ class DlaTree(nn.Module): self.project = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), nn.BatchNorm2d(out_channels)) + self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut) else: cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut)) self.tree1 = DlaTree( levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs) self.tree2 = DlaTree( levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs) - if levels == 1: - self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut) + self.root = None self.level_root = level_root self.root_dim = root_dim self.levels = levels - def forward(self, x, shortcut=None, children=None): - children = [] if children is None else children + def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None): + if children is None: + children = [] bottom = self.downsample(x) shortcut = self.project(bottom) if self.level_root: children.append(bottom) x1 = self.tree1(x, shortcut) - if self.levels == 1: + if self.root is not None: # levels == 1 x2 = self.tree2(x1) - x = self.root(x2, x1, *children) + x = self.root([x2, x1] + children) else: children.append(x1) - x = self.tree2(x1, children=children) + x = self.tree2(x1, None, children) return x class DLA(nn.Module): - def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3, - cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False, - drop_rate=0.0, global_pool='avg'): + def __init__( + self, levels, channels, output_stride=32, num_classes=1000, in_chans=3, global_pool='avg', + cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False, drop_rate=0.0): super(DLA, self).__init__() self.channels = channels self.num_classes = num_classes @@ -302,13 +309,32 @@ class DLA(nn.Module): modules = [] for i in range(convs): modules.extend([ - nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1, - padding=dilation, bias=False, dilation=dilation), + nn.Conv2d( + inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1, + padding=dilation, bias=False, dilation=dilation), nn.BatchNorm2d(planes), nn.ReLU(inplace=True)]) inplanes = planes return nn.Sequential(*modules) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^base_layer', + blocks=r'^level(\d+)' if coarse else [ + # an unusual arch, this achieves somewhat more granularity without getting super messy + 
(r'^level(\d+).tree(\d+)', None), + (r'^level(\d+).root', (2,)), + (r'^level(\d+)', (1,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.fc @@ -328,13 +354,19 @@ class DLA(nn.Module): x = self.level5(x) return x - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.fc(x) - x = self.flatten(x) + if pre_logits: + return x.flatten(1) + else: + x = self.fc(x) + return self.flatten(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 79358695..616efdbb 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -166,16 +166,17 @@ class DualPathBlock(nn.Module): class DPN(nn.Module): - def __init__(self, small=False, num_init_features=64, k_r=96, groups=32, - b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32, - num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', fc_act=nn.ELU): + def __init__( + self, small=False, num_init_features=64, k_r=96, groups=32, global_pool='avg', + b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32, + num_classes=1000, in_chans=3, drop_rate=0., fc_act_layer=nn.ELU): super(DPN, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate self.b = b assert output_stride == 32 # FIXME look into dilation support norm_layer = partial(BatchNormAct2d, eps=.001) - fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False) + fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act_layer, inplace=False) bw_factor = 1 if small else 4 blocks = OrderedDict() @@ -239,6 +240,22 @@ class DPN(nn.Module): self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^features.conv1', + blocks=[ + (r'^features.conv(\d+)' if coarse else r'^features.conv(\d+)_(\d+)', None), + (r'^features.conv5_bn_ac', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.classifier @@ -251,13 +268,19 @@ class DPN(nn.Module): def forward_features(self, x): return self.features(x) - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.classifier(x) - x = self.flatten(x) + if pre_logits: + return x.flatten(1) + else: + x = self.classifier(x) + return self.flatten(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 7272468a..96c30c76 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -48,7 +48,7 @@ from .efficientnet_blocks import SqueezeExcite from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from 
.features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features +from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct from .registry import register_model @@ -470,9 +470,10 @@ class EfficientNet(nn.Module): * TinyNet """ - def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False, - output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None, - se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'): + def __init__( + self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False, + output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None, + se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'): super(EfficientNet, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d @@ -481,6 +482,7 @@ class EfficientNet(nn.Module): self.num_classes = num_classes self.num_features = num_features self.drop_rate = drop_rate + self.grad_checkpointing = False # Stem if not fix_stem: @@ -511,6 +513,21 @@ class EfficientNet(nn.Module): layers.extend([nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^conv_stem|bn1', + blocks=[ + (r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None), + (r'conv_head|bn2', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.classifier @@ -522,17 +539,24 @@ class EfficientNet(nn.Module): def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x, flatten=True) + else: + x = self.blocks(x) x = self.conv_head(x) x = self.bn2(x) return x - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) - return self.classifier(x) + return x if pre_logits else self.classifier(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x class EfficientNetFeatures(nn.Module): @@ -542,9 +566,10 @@ class EfficientNetFeatures(nn.Module): and object detection models. 
""" - def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, - stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, - act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): + def __init__( + self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, + stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, + act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): super(EfficientNetFeatures, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d diff --git a/timm/models/features.py b/timm/models/features.py index b1d6890f..0bc46419 100644 --- a/timm/models/features.py +++ b/timm/models/features.py @@ -86,7 +86,7 @@ class FeatureHooks: This module helps with the setup and extraction of hooks for extracting features from internal nodes in a model by node name. This works quite well in eager Python but needs - redesign for torcscript. + redesign for torchscript. """ def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): @@ -97,7 +97,7 @@ class FeatureHooks: m = modules[hook_name] hook_id = out_map[i] if out_map else hook_name hook_fn = partial(self._collect_output_hook, hook_id) - hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type + hook_type = h.get('hook_type', default_hook_type) if hook_type == 'forward_pre': m.register_forward_pre_hook(hook_fn) elif hook_type == 'forward': diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index c7ca0f8b..cbb51980 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -89,13 +89,13 @@ class FeatureGraphNet(nn.Module): return list(self.graph_module(x).values()) -class FeatureExtractNet(nn.Module): +class GraphExtractNet(nn.Module): """ A standalone feature extraction wrapper that maps dict -> list or single tensor NOTE: * one can use feature_extractor directly if dictionary output is desired * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info metadata for builtin feature extraction mode - * feature_extractor can be used directly if dictionary output is acceptable + * create_feature_extractor can be used directly if dictionary output is acceptable Args: model: model to extract features from diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 684d6651..bedb04a5 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -15,7 +15,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import SelectAdaptivePool2d, Linear, make_divisible from .efficientnet_blocks import SqueezeExcite, ConvBnAct -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .registry import register_model @@ -24,7 +24,7 @@ __all__ = ['GhostNet'] def _cfg(url='', **kwargs): return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'conv_stem', 'classifier': 'classifier', @@ -133,13 +133,15 @@ class GhostBottleneck(nn.Module): class GhostNet(nn.Module): - def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, 
output_stride=32, global_pool='avg'): + def __init__( + self, cfgs, num_classes=1000, width=1.0, in_chans=3, output_stride=32, global_pool='avg', drop_rate=0.2): super(GhostNet, self).__init__() # setting of inverted residual blocks assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' self.cfgs = cfgs self.num_classes = num_classes - self.dropout = dropout + self.drop_rate = drop_rate + self.grad_checkpointing = False self.feature_info = [] # building first layer @@ -184,6 +186,24 @@ class GhostNet(nn.Module): self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity() + # FIXME init + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^conv_stem|bn1', + blocks=[ + (r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None), + (r'conv_head', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.classifier @@ -198,18 +218,25 @@ class GhostNet(nn.Module): x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x, flatten=True) + else: + x = self.blocks(x) + return x + + def forward_head(self, x): x = self.global_pool(x) x = self.conv_head(x) x = self.act2(x) + x = self.flatten(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) return x def forward(self, x): x = self.forward_features(x) - x = self.flatten(x) - if self.dropout > 0.: - x = F.dropout(x, p=self.dropout, training=self.training) - x = self.classifier(x) + x = self.forward_head(x) return x diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index 6a2168c9..17a197a0 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -8,6 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ from collections import OrderedDict +import torch import torch.nn as nn import torch.nn.functional as F @@ -178,6 +179,23 @@ class Xception65(nn.Module): self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^conv[12]|bn[12]', + blocks=[ + (r'^mid.block(\d+)', None), + (r'^block(\d+)', None), + (r'^conv[345]|bn[345]', (99,)), + ], + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, "gradient checkpointing not supported" + + @torch.jit.ignore def get_classifier(self): return self.fc @@ -222,14 +240,18 @@ class Xception65(nn.Module): x = self.act5(x) return x - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x): x = self.global_pool(x) if self.drop_rate: F.dropout(x, self.drop_rate, training=self.training) x = self.fc(x) return x + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + def _create_gluon_xception(variant, pretrained=False, **kwargs): return build_model_with_cfg( diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index a4e42de0..9e1c2a9b 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -13,7 +13,7 @@ from .registry import register_model def _cfg(url='', **kwargs): 
return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'conv_stem', 'classifier': 'classifier', diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 85be3377..dd45e5ea 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -2,16 +2,20 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import collections.abc import logging -import os import math -from collections import OrderedDict +import os +import re +from collections import OrderedDict, defaultdict from copy import deepcopy -from typing import Any, Callable, Optional, Tuple, Dict +from itertools import chain +from typing import Any, Callable, Optional, Tuple, Dict, Union import torch import torch.nn as nn from torch.hub import load_state_dict_from_url +from torch.utils.checkpoint import checkpoint from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .fx_features import FeatureGraphNet @@ -68,7 +72,8 @@ def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True): raise NotImplementedError('Model cannot load numpy checkpoint') return state_dict = load_state_dict(checkpoint_path, use_ema) - model.load_state_dict(state_dict, strict=strict) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): @@ -479,7 +484,7 @@ def build_model_with_cfg( pretrained_cfg: Optional[Dict] = None, model_cfg: Optional[Any] = None, feature_cfg: Optional[Dict] = None, - pretrained_strict: bool = True, + pretrained_strict: bool = False, pretrained_filter_fn: Optional[Callable] = None, pretrained_custom_load: bool = False, kwargs_filter: Optional[Tuple[str]] = None, @@ -592,3 +597,194 @@ def named_modules(module: nn.Module, name='', depth_first=True, include_root=Fal module=child_module, name=child_name, depth_first=depth_first, include_root=True) if depth_first and include_root: yield name, module + + +def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): + if module._parameters and not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules_with_params( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if module._parameters and depth_first and include_root: + yield name, module + + +MATCH_PREV_GROUP = (99999,) + + +def group_with_matcher( + named_objects, + group_matcher: Union[Dict, Callable], + output_values: bool = False, + reverse: bool = False +): + if isinstance(group_matcher, dict): + # dictionary matcher contains a dict of raw-string regex expr that must be compiled + compiled = [] + for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): + if mspec is None: + continue + # map all matching specifications into 3-tuple (compiled re, prefix, suffix) + if isinstance(mspec, (tuple, list)): + # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) + for sspec in mspec: + compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] + else: + compiled += [(re.compile(mspec), (group_ordinal,), None)] + group_matcher = compiled + + def 
_get_grouping(name): + if isinstance(group_matcher, (list, tuple)): + for match_fn, prefix, suffix in group_matcher: + r = match_fn.match(name) + if r: + parts = (prefix, r.groups(), suffix) + # map all tuple elem to int for numeric sort, filter out None entries + return tuple(map(float, chain.from_iterable(filter(None, parts)))) + return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal + else: + ord = group_matcher(name) + if not isinstance(ord, collections.abc.Iterable): + return ord, + return tuple(ord) + + # map layers into groups via ordinals (ints or tuples of ints) from matcher + grouping = defaultdict(list) + for k, v in named_objects: + grouping[_get_grouping(k)].append(v if output_values else k) + + # remap to integers + layer_id_to_param = defaultdict(list) + lid = -1 + for k in sorted(filter(lambda x: x is not None, grouping.keys())): + if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: + lid += 1 + print(lid, k, grouping[k]) + layer_id_to_param[lid].extend(grouping[k]) + + if reverse: + assert not output_values, "reverse mapping only sensible for name output" + # output reverse mapping + param_to_layer_id = {} + for lid, lm in layer_id_to_param.items(): + for n in lm: + param_to_layer_id[n] = lid + return param_to_layer_id + + return layer_id_to_param + + +def group_parameters( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) + + +def group_modules( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) + + +def checkpoint_seq( + functions, + x, + every=1, + flatten=False, + skip_last=False, + preserve_rng_state=True +): + r"""A helper function for checkpointing sequential models. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a sequence into segments + and checkpoint each segment. All segments except run in :func:`torch.no_grad` + manner, i.e., not storing the intermediate activations. The inputs of each + checkpointed segment will be saved for re-running the segment in the backward pass. + + See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. + + .. warning:: + Checkpointing currently only supports :func:`torch.autograd.backward` + and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` + is not supported. + + .. warning: + At least one of the inputs needs to have :code:`requires_grad=True` if + grads are needed for model inputs, otherwise the checkpointed part of the + model won't have gradients. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. + x: A Tensor that is input to :attr:`functions` + every: checkpoint every-n functions (default: 1) + flatten (bool): flatten nn.Sequential of nn.Sequentials + skip_last (bool): skip checkpointing the last function in the sequence if True + preserve_rng_state (bool, optional, default=True): Omit stashing and restoring + the RNG state during each checkpoint. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + + Example: + >>> model = nn.Sequential(...) 
+ >>> input_var = checkpoint_seq(model, input_var, every=2) + """ + def run_function(start, end, functions): + def forward(_x): + for j in range(start, end + 1): + _x = functions[j](_x) + return _x + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = functions.children() + if flatten: + functions = chain.from_iterable(functions) + if not isinstance(functions, (tuple, list)): + functions = tuple(functions) + + num_checkpointed = len(functions) + if skip_last: + num_checkpointed -= 1 + end = -1 + for start in range(0, num_checkpointed, every): + end = min(start + every - 1, num_checkpointed - 1) + x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) + if skip_last: + return run_function(end + 1, len(functions) - 1, functions)(x) + return x + + +def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): + prefix_is_tuple = isinstance(prefix, tuple) + if isinstance(module_types, str): + if module_types == 'container': + module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) + else: + module_types = (nn.Sequential,) + for name, module in named_modules: + if depth and isinstance(module, module_types): + yield from flatten_modules( + module.named_children(), + depth - 1, + prefix=(name,) if prefix_is_tuple else name, + module_types=module_types, + ) + else: + if prefix_is_tuple: + name = prefix + (name,) + yield name, module + else: + if prefix: + name = '.'.join([prefix, name]) + yield name, module diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 32b4eb32..1a53f44d 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -386,13 +386,13 @@ cfg_cls = dict( class HighResolutionModule(nn.Module): - def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + def __init__(self, num_branches, blocks, num_blocks, num_in_chs, num_channels, fuse_method, multi_scale_output=True): super(HighResolutionModule, self).__init__() self._check_branches( - num_branches, blocks, num_blocks, num_inchannels, num_channels) + num_branches, blocks, num_blocks, num_in_chs, num_channels) - self.num_inchannels = num_inchannels + self.num_in_chs = num_in_chs self.fuse_method = fuse_method self.num_branches = num_branches @@ -403,32 +403,32 @@ class HighResolutionModule(nn.Module): self.fuse_layers = self._make_fuse_layers() self.fuse_act = nn.ReLU(False) - def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): + def _check_branches(self, num_branches, blocks, num_blocks, num_in_chs, num_channels): error_msg = '' if num_branches != len(num_blocks): error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks)) elif num_branches != len(num_channels): error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels)) - elif num_branches != len(num_inchannels): - error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(num_branches, len(num_inchannels)) + elif num_branches != len(num_in_chs): + error_msg = 'NUM_BRANCHES({}) <> num_in_chs({})'.format(num_branches, len(num_in_chs)) if error_msg: _logger.error(error_msg) raise ValueError(error_msg) def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): downsample = None - if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + if stride != 1 or self.num_in_chs[branch_index] != num_channels[branch_index] * block.expansion: downsample = nn.Sequential( nn.Conv2d( - 
self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, + self.num_in_chs[branch_index], num_channels[branch_index] * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM), ) - layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)] - self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion + layers = [block(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample)] + self.num_in_chs[branch_index] = num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): - layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) + layers.append(block(self.num_in_chs[branch_index], num_channels[branch_index])) return nn.Sequential(*layers) @@ -444,15 +444,15 @@ class HighResolutionModule(nn.Module): return nn.Identity() num_branches = self.num_branches - num_inchannels = self.num_inchannels + num_in_chs = self.num_in_chs fuse_layers = [] for i in range(num_branches if self.multi_scale_output else 1): fuse_layer = [] for j in range(num_branches): if j > i: fuse_layer.append(nn.Sequential( - nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), - nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM), + nn.Conv2d(num_in_chs[j], num_in_chs[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_in_chs[i], momentum=_BN_MOMENTUM), nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'))) elif j == i: fuse_layer.append(nn.Identity()) @@ -460,14 +460,14 @@ class HighResolutionModule(nn.Module): conv3x3s = [] for k in range(i - j): if k == i - j - 1: - num_outchannels_conv3x3 = num_inchannels[i] + num_outchannels_conv3x3 = num_in_chs[i] conv3x3s.append(nn.Sequential( - nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM))) else: - num_outchannels_conv3x3 = num_inchannels[j] + num_outchannels_conv3x3 = num_in_chs[j] conv3x3s.append(nn.Sequential( - nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM), nn.ReLU(False))) fuse_layer.append(nn.Sequential(*conv3x3s)) @@ -475,8 +475,8 @@ class HighResolutionModule(nn.Module): return nn.ModuleList(fuse_layers) - def get_num_inchannels(self): - return self.num_inchannels + def get_num_in_chs(self): + return self.num_in_chs def forward(self, x: List[torch.Tensor]): if self.num_branches == 1: @@ -652,7 +652,7 @@ class HighResolutionNet(nn.Module): return nn.Sequential(*layers) - def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): + def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True): num_modules = layer_config['NUM_MODULES'] num_branches = layer_config['NUM_BRANCHES'] num_blocks = layer_config['NUM_BLOCKS'] @@ -665,12 +665,13 @@ class HighResolutionNet(nn.Module): # multi_scale_output is only used last module reset_multi_scale_output = multi_scale_output or i < num_modules - 1 modules.append(HighResolutionModule( - num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output) + num_branches, block, num_blocks, num_in_chs, num_channels, fuse_method, reset_multi_scale_output) ) - num_inchannels = 
modules[-1].get_num_inchannels() + num_in_chs = modules[-1].get_num_in_chs() - return nn.Sequential(*modules), num_inchannels + return nn.Sequential(*modules), num_in_chs + @torch.jit.ignore def init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -680,6 +681,23 @@ class HighResolutionNet(nn.Module): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^conv[12]|bn[12]', + blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [ + (r'^layer(\d+).(\d+)', None), + (r'^stage(\d+).(\d+)', None), + (r'^transition(\d+)', (99999,)), + ], + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, "gradient checkpointing not supported" + + @torch.jit.ignore def get_classifier(self): return self.classifier @@ -712,20 +730,24 @@ class HighResolutionNet(nn.Module): # Stages yl = self.stages(x) - - # Classification Head + if self.incre_modules is None or self.downsamp_modules is None: + return yl y = self.incre_modules[0](yl[0]) for i, down in enumerate(self.downsamp_modules): y = self.incre_modules[i + 1](yl[i + 1]) + down(y) y = self.final_layer(y) return y - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): + # Classification Head x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.classifier(x) + return x if pre_logits else self.classifier(x) + + def forward(self, x): + y = self.forward_features(x) + x = self.forward_head(y) return x diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index d4aced05..fa7b8ec8 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -7,7 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, flatten_modules from .layers import create_classifier from .registry import register_model @@ -300,6 +300,30 @@ class InceptionResnetV2(nn.Module): self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + @torch.jit.ignore + def group_matcher(self, coarse=False): + module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))} + module_map.pop(('classif',)) + + def _matcher(name): + if any([name.startswith(n) for n in ('conv2d_1', 'conv2d_2')]): + return 0 + elif any([name.startswith(n) for n in ('conv2d_3', 'conv2d_4')]): + return 1 + elif any([name.startswith(n) for n in ('block8', 'conv2d_7')]): + return len(module_map) + 1 + else: + for k in module_map.keys(): + if k == tuple(name.split('.')[:len(k)]): + return module_map[k] + return float('inf') + return _matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, "checkpointing not supported" + + @torch.jit.ignore def get_classifier(self): return self.classif @@ -325,12 +349,15 @@ class InceptionResnetV2(nn.Module): x = self.conv2d_7b(x) return x - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate > 0: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.classif(x) + return x if pre_logits else self.classif(x) + + def forward(self, x): + x = 
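
# Toy illustration of how a group_matcher spec is interpreted (simplified regexes, not the
# model's actual consumer code in timm): 'stem' parameters map to group 0, the first capture
# group of the 'blocks' pattern orders the remaining groups, and unmatched names (e.g. the
# classifier) fall to the end.
import re

matcher = dict(stem=r'^conv_stem|bn1', blocks=r'^blocks\.(\d+)')

def toy_group(param_name):
    if re.match(matcher['stem'], param_name):
        return 0
    m = re.match(matcher['blocks'], param_name)
    return (1 + int(m.group(1))) if m else float('inf')

assert toy_group('conv_stem.weight') == 0
assert toy_group('blocks.3.0.conv_dw.weight') == 4
assert toy_group('classifier.weight') == float('inf')
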
self.forward_features(x) + x = self.forward_head(x) return x diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index eb6fb2cf..e34de657 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg +from .helpers import build_model_with_cfg, resolve_pretrained_cfg, flatten_modules from .registry import register_model from .layers import trunc_normal_, create_classifier, Linear @@ -336,47 +336,57 @@ class InceptionV3(nn.Module): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) + @torch.jit.ignore + def group_matcher(self, coarse=False): + module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))} + module_map.pop(('fc',)) + + def _matcher(name): + if any([name.startswith(n) for n in ('Conv2d_1', 'Conv2d_2')]): + return 0 + elif any([name.startswith(n) for n in ('Conv2d_3', 'Conv2d_4')]): + return 1 + else: + for k in module_map.keys(): + if k == tuple(name.split('.')[:len(k)]): + return module_map[k] + return float('inf') + return _matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + def forward_preaux(self, x): - # N x 3 x 299 x 299 - x = self.Conv2d_1a_3x3(x) - # N x 32 x 149 x 149 - x = self.Conv2d_2a_3x3(x) - # N x 32 x 147 x 147 - x = self.Conv2d_2b_3x3(x) - # N x 64 x 147 x 147 - x = self.Pool1(x) - # N x 64 x 73 x 73 - x = self.Conv2d_3b_1x1(x) - # N x 80 x 73 x 73 - x = self.Conv2d_4a_3x3(x) - # N x 192 x 71 x 71 - x = self.Pool2(x) - # N x 192 x 35 x 35 - x = self.Mixed_5b(x) - # N x 256 x 35 x 35 - x = self.Mixed_5c(x) - # N x 288 x 35 x 35 - x = self.Mixed_5d(x) - # N x 288 x 35 x 35 - x = self.Mixed_6a(x) - # N x 768 x 17 x 17 - x = self.Mixed_6b(x) - # N x 768 x 17 x 17 - x = self.Mixed_6c(x) - # N x 768 x 17 x 17 - x = self.Mixed_6d(x) - # N x 768 x 17 x 17 - x = self.Mixed_6e(x) - # N x 768 x 17 x 17 + x = self.Conv2d_1a_3x3(x) # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) # N x 64 x 147 x 147 + x = self.Pool1(x) # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) # N x 192 x 71 x 71 + x = self.Pool2(x) # N x 192 x 35 x 35 + x = self.Mixed_5b(x) # N x 256 x 35 x 35 + x = self.Mixed_5c(x) # N x 288 x 35 x 35 + x = self.Mixed_5d(x) # N x 288 x 35 x 35 + x = self.Mixed_6a(x) # N x 768 x 17 x 17 + x = self.Mixed_6b(x) # N x 768 x 17 x 17 + x = self.Mixed_6c(x) # N x 768 x 17 x 17 + x = self.Mixed_6d(x) # N x 768 x 17 x 17 + x = self.Mixed_6e(x) # N x 768 x 17 x 17 return x def forward_postaux(self, x): - x = self.Mixed_7a(x) - # N x 1280 x 8 x 8 - x = self.Mixed_7b(x) - # N x 2048 x 8 x 8 - x = self.Mixed_7c(x) - # N x 2048 x 8 x 8 + x = self.Mixed_7a(x) # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) # N x 2048 x 8 x 8 return x def forward_features(self, x): @@ -384,21 +394,18 @@ class InceptionV3(nn.Module): x = self.forward_postaux(x) return x - def 
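
# Sketch of the forward_features / forward_head split introduced by these refactors
# (assumes the patched model; shapes shown are for the default 299x299 InceptionResnetV2 config):
import torch
import timm

model = timm.create_model('inception_resnet_v2', pretrained=False).eval()
x = torch.randn(1, 3, 299, 299)
feats = model.forward_features(x)                     # (1, 1536, 8, 8) unpooled feature map
pooled = model.forward_head(feats, pre_logits=True)   # (1, 1536) pooled features, no classifier
logits = model.forward_head(feats)                    # (1, 1000) classifier output
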
get_classifier(self): - return self.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.num_classes = num_classes - self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) - - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x): x = self.global_pool(x) if self.drop_rate > 0: x = F.dropout(x, p=self.drop_rate, training=self.training) x = self.fc(x) return x + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + class InceptionV3Aux(InceptionV3): """InceptionV3 with AuxLogits @@ -416,10 +423,7 @@ class InceptionV3Aux(InceptionV3): def forward(self, x): x, aux = self.forward_features(x) - x = self.global_pool(x) - if self.drop_rate > 0: - x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.fc(x) + x = self.forward_head(x) return x, aux diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index f95db28e..5f4e208f 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -283,6 +283,18 @@ class InceptionV4(nn.Module): self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^features\.[012]\.', + blocks=r'^features\.(\d+)' + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.last_linear @@ -294,12 +306,15 @@ class InceptionV4(nn.Module): def forward_features(self, x): return self.features(x) - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate > 0: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.last_linear(x) + return x if pre_logits else self.last_linear(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x diff --git a/timm/models/layers/attention_pool2d.py b/timm/models/layers/attention_pool2d.py index 66e49b8a..a13a6881 100644 --- a/timm/models/layers/attention_pool2d.py +++ b/timm/models/layers/attention_pool2d.py @@ -7,67 +7,16 @@ https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/cli Hacked together by / Copyright 2021 Ross Wightman """ -import math -from typing import List, Union, Tuple +from typing import Union, Tuple import torch import torch.nn as nn from .helpers import to_2tuple +from .pos_embed import apply_rot_embed, RotaryEmbedding from .weight_init import trunc_normal_ -def rot(x): - return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) - - -def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): - return x * cos_emb + rot(x) * sin_emb - - -def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): - if isinstance(x, torch.Tensor): - x = [x] - return [t * cos_emb + rot(t) * sin_emb for t in x] - - -class RotaryEmbedding(nn.Module): - """ Rotary position embedding - - NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not - been well tested, and will likely change. It will be moved to its own file. 
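
# Arithmetic sketch of the rotary helpers being moved here: rot() maps each (even, odd)
# channel pair (a, b) to (-b, a), so x * cos + rot(x) * sin rotates every 2-channel pair
# by the angle encoded per position in (sin, cos).
import torch

def rot(x):
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)

x = torch.tensor([1., 2., 3., 4.])
assert torch.equal(rot(x), torch.tensor([-2., 1., -4., 3.]))

theta = torch.tensor(0.5)
rotated = x * theta.cos() + rot(x) * theta.sin()   # pairs (1, 2) and (3, 4) rotated by 0.5 rad
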
- - The following impl/resources were referenced for this impl: - * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py - * https://blog.eleuther.ai/rotary-embeddings/ - """ - def __init__(self, dim, max_freq=4): - super().__init__() - self.dim = dim - self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False) - - def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None): - """ - NOTE: shape arg should include spatial dim only - """ - device = device or self.bands.device - dtype = dtype or self.bands.dtype - if not isinstance(shape, torch.Size): - shape = torch.Size(shape) - N = shape.numel() - grid = torch.stack(torch.meshgrid( - [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1) - emb = grid * math.pi * self.bands - sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1) - cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1) - return sin, cos - - def forward(self, x): - # assuming channel-first tensor where spatial dim are >= 2 - sin_emb, cos_emb = self.get_embed(x.shape[2:]) - return apply_rot_embed(x, sin_emb, cos_emb) - - class RotAttentionPool2d(nn.Module): """ Attention based 2D feature pooling w/ rotary (relative) pos embedding. This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. @@ -103,7 +52,6 @@ class RotAttentionPool2d(nn.Module): def forward(self, x): B, _, H, W = x.shape N = H * W - sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:]) x = x.reshape(B, -1, N).permute(0, 2, 1) x = torch.cat([x.mean(1, keepdim=True), x], dim=1) @@ -112,6 +60,7 @@ class RotAttentionPool2d(nn.Module): q, k, v = x[0], x[1], x[2] qc, q = q[:, :, :1], q[:, :, 1:] + sin_emb, cos_emb = self.pos_embed.get_embed((H, W)) q = apply_rot_embed(q, sin_emb, cos_emb) q = torch.cat([qc, q], dim=2) diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py index 798748da..3ac33387 100644 --- a/timm/models/layers/classifier.py +++ b/timm/models/layers/classifier.py @@ -45,10 +45,12 @@ class ClassifierHead(nn.Module): self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() - def forward(self, x): + def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate: x = F.dropout(x, p=float(self.drop_rate), training=self.training) - x = self.fc(x) - x = self.flatten(x) - return x + if pre_logits: + return x.flatten(1) + else: + x = self.fc(x) + return self.flatten(x) diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index f48d9a83..42636236 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -97,7 +97,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5): class EvoNorm2dB0(nn.Module): - def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-3, **_): super().__init__() self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum @@ -237,7 +237,7 @@ class EvoNorm2dS0(nn.Module): class EvoNorm2dS0a(EvoNorm2dS0): - def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_): + def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-3, **_): super().__init__( num_features, groups=groups, 
group_size=group_size, apply_act=apply_act, eps=eps) @@ -290,7 +290,7 @@ class EvoNorm2dS1(nn.Module): class EvoNorm2dS1a(EvoNorm2dS1): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) @@ -338,7 +338,7 @@ class EvoNorm2dS2(nn.Module): class EvoNorm2dS2a(EvoNorm2dS2): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) diff --git a/timm/models/layers/pos_embed.py b/timm/models/layers/pos_embed.py new file mode 100644 index 00000000..99a122a0 --- /dev/null +++ b/timm/models/layers/pos_embed.py @@ -0,0 +1,207 @@ +import math +from typing import List, Tuple, Optional, Union + +import torch +from torch import nn as nn + + +def pixel_freq_bands( + num_bands: int, + max_freq: float = 224., + linear_bands: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + if linear_bands: + bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device) + else: + bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device) + return bands * torch.pi + + +def inv_freq_bands( + num_bands: int, + temperature: float = 100000., + step: int = 2, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)) + return inv_freq + + +def build_sincos2d_pos_embed( + feat_shape: List[int], + dim: int = 64, + temperature: float = 10000., + reverse_coord: bool = False, + interleave_sin_cos: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None +) -> torch.Tensor: + """ + + Args: + feat_shape: + dim: + temperature: + reverse_coord: stack grid order W, H instead of H, W + interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos + dtype: + device: + + Returns: + + """ + assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding' + pos_dim = dim // 4 + bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device) + + if reverse_coord: + feat_shape = feat_shape[::-1] # stack W, H instead of H, W + grid = torch.stack( + torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1) + pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) + # FIXME add support for unflattened spatial dim? 
+ + stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos + pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) + return pos_emb + + +def build_fourier_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + num_bands: int = 64, + max_res: int = 224, + linear_bands: bool = False, + include_grid: bool = False, + concat_out: bool = True, + in_pixels: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + if bands is None: + if in_pixels: + bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device) + else: + bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device) + else: + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype + + if in_pixels: + grid = torch.stack(torch.meshgrid( + [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + else: + grid = torch.stack(torch.meshgrid( + [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + grid = grid.unsqueeze(-1) + pos = grid * bands + + pos_sin, pos_cos = pos.sin(), pos.cos() + out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos) + # FIXME torchscript doesn't like multiple return types, probably need to always cat? + if concat_out: + out = torch.cat(out, dim=-1) + return out + + +class FourierEmbed(nn.Module): + + def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False): + super().__init__() + self.max_res = max_res + self.num_bands = num_bands + self.concat_grid = concat_grid + self.keep_spatial = keep_spatial + self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False) + + def forward(self, x): + B, C = x.shape[:2] + feat_shape = x.shape[2:] + emb = build_fourier_pos_embed( + feat_shape, + self.bands, + include_grid=self.concat_grid, + dtype=x.dtype, + device=x.device) + emb = emb.transpose(-1, -2).flatten(len(feat_shape)) + batch_expand = (B,) + (-1,) * (x.ndim - 1) + + # FIXME support nD + if self.keep_spatial: + x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1) + else: + x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1) + x = x.reshape(B, feat_shape.numel(), -1) + + return x + + +def rot(x): + return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) + + +def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): + return x * cos_emb + rot(x) * sin_emb + + +def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): + if isinstance(x, torch.Tensor): + x = [x] + return [t * cos_emb + rot(t) * sin_emb for t in x] + + +def apply_rot_embed_split(x: torch.Tensor, emb): + split = emb.shape[-1] // 2 + return x * emb[:, :split] + rot(x) * emb[:, split:] + + +def build_rotary_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + dim: int = 64, + max_freq: float = 224, + linear_bands: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + """ + NOTE: shape arg should include spatial dim only + """ + feat_shape = torch.Size(feat_shape) + + sin_emb, cos_emb = build_fourier_pos_embed( + feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands, + concat_out=False, device=device, dtype=dtype) + N = feat_shape.numel() + sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1) 
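
# Shape sketch for the new position-embedding builders (assumes the pos_embed module added
# in this patch): both return one row per spatial location with `dim` channels.
import torch
from timm.models.layers.pos_embed import build_sincos2d_pos_embed, build_rotary_pos_embed

pe = build_sincos2d_pos_embed([7, 7], dim=64)               # (49, 64) fixed sin-cos table
sin_emb, cos_emb = build_rotary_pos_embed([7, 7], dim=64)   # each (49, 64), used by apply_rot_embed
assert pe.shape == (49, 64) and sin_emb.shape == (49, 64) and cos_emb.shape == (49, 64)
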
+ cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1) + return sin_emb, cos_emb + + +class RotaryEmbedding(nn.Module): + """ Rotary position embedding + + NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not + been well tested, and will likely change. It will be moved to its own file. + + The following impl/resources were referenced for this impl: + * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py + * https://blog.eleuther.ai/rotary-embeddings/ + """ + def __init__(self, dim, max_res=224, linear_bands: bool = False): + super().__init__() + self.dim = dim + self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False) + + def get_embed(self, shape: List[int]): + return build_rotary_pos_embed(shape, self.bands) + + def forward(self, x): + # assuming channel-first tensor where spatial dim are >= 2 + sin_emb, cos_emb = self.get_embed(x.shape[2:]) + return apply_rot_embed(x, sin_emb, cos_emb) diff --git a/timm/models/levit.py b/timm/models/levit.py index 5c21f50f..b1dae17a 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -32,7 +32,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .layers import to_ntuple, get_act_layer from .vision_transformer import trunc_normal_ from .registry import register_model @@ -65,6 +65,8 @@ default_cfgs = dict( levit_384=_cfg( url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' ), + + levit_256d=_cfg(url='', classifier='head.l'), ) model_cfgs = dict( @@ -78,6 +80,9 @@ model_cfgs = dict( embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)), levit_384=dict( embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)), + + levit_256d=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6)), ) __all__ = ['Levit'] @@ -113,15 +118,21 @@ def levit_384(pretrained=False, use_conv=False, **kwargs): 'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs) +@register_model +def levit_256d(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_256d', pretrained=pretrained, use_conv=use_conv, distilled=False, **kwargs) + + class ConvNorm(nn.Sequential): def __init__( - self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1, resolution=-10000): super().__init__() - self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) - bn = nn.BatchNorm2d(b) - nn.init.constant_(bn.weight, bn_weight_init) - nn.init.constant_(bn.bias, 0) - self.add_module('bn', bn) + self.add_module('c', nn.Conv2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) @torch.no_grad() def fuse(self): @@ -138,13 +149,12 @@ class ConvNorm(nn.Sequential): class LinearNorm(nn.Sequential): - def __init__(self, a, b, bn_weight_init=1, resolution=-100000): + def __init__(self, in_features, out_features, bn_weight_init=1, resolution=-100000): super().__init__() - self.add_module('c', nn.Linear(a, b, bias=False)) - bn = nn.BatchNorm1d(b) - nn.init.constant_(bn.weight, bn_weight_init) - 
nn.init.constant_(bn.bias, 0) - self.add_module('bn', bn) + self.add_module('c', nn.Linear(in_features, out_features, bias=False)) + self.add_module('bn', nn.BatchNorm1d(out_features)) + + nn.init.constant_(self.bn.weight, bn_weight_init) @torch.no_grad() def fuse(self): @@ -163,14 +173,14 @@ class LinearNorm(nn.Sequential): class NormLinear(nn.Sequential): - def __init__(self, a, b, bias=True, std=0.02): + def __init__(self, in_features, out_features, bias=True, std=0.02): super().__init__() - self.add_module('bn', nn.BatchNorm1d(a)) - l = nn.Linear(a, b, bias=bias) - trunc_normal_(l.weight, std=std) - if bias: - nn.init.constant_(l.bias, 0) - self.add_module('l', l) + self.add_module('bn', nn.BatchNorm1d(in_features)) + self.add_module('l', nn.Linear(in_features, out_features, bias=bias)) + + trunc_normal_(self.l.weight, std=std) + if self.l.bias is not None: + nn.init.constant_(self.l.bias, 0) @torch.no_grad() def fuse(self): @@ -231,34 +241,26 @@ class Attention(nn.Module): def __init__( self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): super().__init__() - + ln_layer = ConvNorm if use_conv else LinearNorm + self.use_conv = use_conv self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * num_heads - self.attn_ratio = attn_ratio - self.use_conv = use_conv - ln_layer = ConvNorm if self.use_conv else LinearNorm - h = self.dh + nh_kd * 2 - self.qkv = ln_layer(dim, h, resolution=resolution) + self.key_attn_dim = key_dim * num_heads + self.val_dim = int(attn_ratio * key_dim) + self.val_attn_dim = int(attn_ratio * key_dim) * num_heads + + self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, resolution=resolution) self.proj = nn.Sequential( act_layer(), - ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution)) - - points = list(itertools.product(range(resolution), range(resolution))) - N = len(points) - attention_offsets = {} - idxs = [] - for p1 in points: - for p2 in points: - offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + ln_layer(self.val_attn_dim, dim, bn_weight_init=0, resolution=resolution) + ) + + self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution ** 2)) + pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1) + rel_pos = (pos[..., :, None] - pos[..., None, :]).abs() + rel_pos = (rel_pos[0] * resolution) + rel_pos[1] + self.register_buffer('attention_bias_idxs', rel_pos) self.ab = {} @torch.no_grad() @@ -279,7 +281,8 @@ class Attention(nn.Module): def forward(self, x): # x (B,C,H,W) if self.use_conv: B, C, H, W = x.shape - q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + q, k, v = self.qkv(x).view( + B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2) attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) @@ -287,8 +290,8 @@ class Attention(nn.Module): x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) else: B, N, C = x.shape - qkv = self.qkv(x) - q, k, v = 
qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q, k, v = self.qkv(x).view( + B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3) q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 3, 1) v = v.permute(0, 2, 1, 3) @@ -296,7 +299,7 @@ class Attention(nn.Module): attn = q @ k * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim) x = self.proj(x) return x @@ -306,17 +309,18 @@ class AttentionSubsample(nn.Module): def __init__( self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, - act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False): + act_layer=None, stride=2, resolution=14, resolution_out=7, use_conv=False): super().__init__() + self.stride = stride self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = self.d * self.num_heads - self.attn_ratio = attn_ratio - self.resolution_ = resolution_ - self.resolution_2 = resolution_ ** 2 + self.key_attn_dim = key_dim * num_heads + self.val_dim = int(attn_ratio * key_dim) + self.val_attn_dim = self.val_dim * self.num_heads + self.resolution = resolution + self.resolution_out_area = resolution_out ** 2 + self.use_conv = use_conv if self.use_conv: ln_layer = ConvNorm @@ -325,34 +329,25 @@ class AttentionSubsample(nn.Module): ln_layer = LinearNorm sub_layer = partial(Subsample, resolution=resolution) - h = self.dh + nh_kd - self.kv = ln_layer(in_dim, h, resolution=resolution) + self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, resolution=resolution) self.q = nn.Sequential( sub_layer(stride=stride), - ln_layer(in_dim, nh_kd, resolution=resolution_)) + ln_layer(in_dim, self.key_attn_dim, resolution=resolution_out) + ) self.proj = nn.Sequential( act_layer(), - ln_layer(self.dh, out_dim, resolution=resolution_)) + ln_layer(self.val_attn_dim, out_dim, resolution=resolution_out) + ) + + self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.resolution ** 2)) + k_pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1) + q_pos = torch.stack(torch.meshgrid( + torch.arange(0, resolution, step=stride), + torch.arange(0, resolution, step=stride))).flatten(1) + rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs() + rel_pos = (rel_pos[0] * resolution) + rel_pos[1] + self.register_buffer('attention_bias_idxs', rel_pos) - self.stride = stride - self.resolution = resolution - points = list(itertools.product(range(resolution), range(resolution))) - points_ = list(itertools.product(range(resolution_), range(resolution_))) - N = len(points) - N_ = len(points_) - attention_offsets = {} - idxs = [] - for p1 in points_: - for p2 in points: - size = 1 - offset = ( - abs(p1[0] * stride - p2[0] + (size - 1) / 2), - abs(p1[1] * stride - p2[1] + (size - 1) / 2)) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) self.ab = {} # per-device attention_biases cache @torch.no_grad() @@ -373,24 +368,24 @@ class AttentionSubsample(nn.Module): def forward(self, x): if self.use_conv: B, C, H, W = x.shape - k, v = 
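
# Sanity-check sketch (my own verification, mirroring the vectorized index construction above):
# the meshgrid form builds a (res*res, res*res) table where entry [q, k] is |dy| * res + |dx|,
# indexing into the per-head bias table of size res*res.
import torch

resolution = 4
pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1)
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
idxs = rel_pos[0] * resolution + rel_pos[1]
assert idxs.shape == (resolution ** 2, resolution ** 2)
assert int(idxs.max()) < resolution ** 2      # all indices fall inside the bias table
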
self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) - q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_out_area) attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution, self.resolution) else: B, N, C = x.shape - k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3) k = k.permute(0, 2, 3, 1) # BHCN v = v.permute(0, 2, 1, 3) # BHNC - q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + q = self.q(x).view(B, self.resolution_out_area, self.num_heads, self.key_dim).permute(0, 2, 1, 3) attn = q @ k * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim) x = self.proj(x) return x @@ -418,35 +413,37 @@ class Levit(nn.Module): down_ops=None, act_layer='hard_swish', attn_act_layer='hard_swish', - distillation=True, use_conv=False, + global_pool='avg', drop_rate=0., drop_path_rate=0.): super().__init__() act_layer = get_act_layer(act_layer) attn_act_layer = get_act_layer(attn_act_layer) + ln_layer = ConvNorm if use_conv else LinearNorm + self.use_conv = use_conv if isinstance(img_size, tuple): # FIXME origin impl passes single img/res dim through whole hierarchy, # not sure this model will be used enough to spend time fixing it. 
assert img_size[0] == img_size[1] img_size = img_size[0] self.num_classes = num_classes + self.global_pool = global_pool self.num_features = embed_dim[-1] self.embed_dim = embed_dim - N = len(embed_dim) - assert len(depth) == len(num_heads) == N - key_dim = to_ntuple(N)(key_dim) - attn_ratio = to_ntuple(N)(attn_ratio) - mlp_ratio = to_ntuple(N)(mlp_ratio) + self.grad_checkpointing = False + + num_stages = len(embed_dim) + assert len(depth) == len(num_heads) == num_stages + key_dim = to_ntuple(num_stages)(key_dim) + attn_ratio = to_ntuple(num_stages)(attn_ratio) + mlp_ratio = to_ntuple(num_stages)(mlp_ratio) down_ops = down_ops or ( # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2), ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2), ('',) ) - self.distillation = distillation - self.use_conv = use_conv - ln_layer = ConvNorm if self.use_conv else LinearNorm self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer) @@ -471,13 +468,13 @@ class Levit(nn.Module): ), drop_path_rate)) if do[0] == 'Subsample': # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - resolution_ = (resolution - 1) // do[5] + 1 + resolution_out = (resolution - 1) // do[5] + 1 self.blocks.append( AttentionSubsample( *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], - resolution=resolution, resolution_=resolution_, use_conv=use_conv)) - resolution = resolution_ + resolution=resolution, resolution_out=resolution_out, use_conv=use_conv)) + resolution = resolution_out if do[4] > 0: # mlp_ratio h = int(embed_dim[i + 1] * do[4]) self.blocks.append( @@ -490,52 +487,87 @@ class Levit(nn.Module): # Classifier head self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = None - if distillation: - self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() @torch.jit.ignore def no_weight_decay(self): return {x for x in self.state_dict().keys() if 'attention_biases' in x} + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): - if self.head_dist is None: - return self.head - else: - return self.head, self.head_dist + return self.head - def reset_classifier(self, num_classes, global_pool='', distillation=None): + def reset_classifier(self, num_classes, global_pool=None, distillation=None): self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() - if distillation is not None: - self.distillation = distillation - if self.distillation: - self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() - else: - self.head_dist = None def forward_features(self, x): x = self.patch_embed(x) if not self.use_conv: x = x.flatten(2).transpose(1, 2) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) return x + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': 
+ x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1) + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = x.mean((-2, -1)) if self.use_conv else x.mean(1) - if self.head_dist is not None: - x, x_dist = self.head(x), self.head_dist(x) - if self.training and not torch.jit.is_scripting(): - return x, x_dist - else: - # during inference, return the average of both classifier predictions - return (x + x_dist) / 2 - else: - x = self.head(x) + x = self.forward_head(x) return x +class LevitDistilled(Levit): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity() + self.distilled_training = False + + @torch.jit.ignore + def get_classifier(self): + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=None, distillation=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def set_distilled_training(self, enable=True): + self.distilled_training = enable + + def forward_head(self, x): + if self.global_pool == 'avg': + x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1) + x, x_dist = self.head(x), self.head_dist(x) + if self.distilled_training and self.training and not torch.jit.is_scripting(): + # only return separate classification predictions when training in distilled mode + return x, x_dist + else: + # during standard train/finetune, inference average the classifier predictions + return (x + x_dist) / 2 + + def checkpoint_filter_fn(state_dict, model): if 'model' in state_dict: # For deit models @@ -547,16 +579,14 @@ def checkpoint_filter_fn(state_dict, model): return state_dict -def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs): +def create_levit(variant, pretrained=False, distilled=True, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model_cfg = dict(**model_cfgs[variant], **kwargs) model = build_model_with_cfg( - Levit, variant, pretrained, + LevitDistilled if distilled else Levit, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, **model_cfg) - #if fuse: - # utils.replace_batchnorm(model) return model diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index ca20fbc4..75cdf84b 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -46,7 +46,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply +from .helpers import build_model_with_cfg, named_apply, checkpoint_seq from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from .registry import register_model @@ -260,10 +260,13 @@ class MlpMixer(nn.Module): drop_path_rate=0., nlhb=False, stem_norm=False, + global_pool='avg', ): super().__init__() self.num_classes = num_classes + self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.grad_checkpointing = False self.stem = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, @@ -279,26 +282,46 @@ class 
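
# Behavior sketch for the distilled head (assumes the patched LevitDistilled; the model name
# is just an example): in distilled-training mode the class and distillation heads are
# returned separately, otherwise their predictions are averaged.
import torch
import timm

model = timm.create_model('levit_256', pretrained=False)   # builds LevitDistilled by default
model.set_distilled_training(True)
model.train()
cls_out, dist_out = model(torch.randn(2, 3, 224, 224))     # two outputs for the distillation loss

model.set_distilled_training(False)
model.eval()
avg_out = model(torch.randn(2, 3, 224, 224))               # single (cls + dist) / 2 output
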
MlpMixer(nn.Module): self.init_weights(nlhb=nlhb) + @torch.jit.ignore def init_weights(self, nlhb=False): head_bias = -math.log(self.num_classes) if nlhb else 0. named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', # stem and embed + blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg') + self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.stem(x) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) x = self.norm(x) return x def forward(self, x): x = self.forward_features(x) - x = x.mean(dim=1) + if self.global_pool == 'avg': + x = x.mean(dim=1) x = self.head(x) return x diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 7171e2ee..79e468a0 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -18,7 +18,7 @@ from .efficientnet_blocks import SqueezeExcite from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, pretrained_cfg_for_features +from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer from .registry import register_model @@ -27,7 +27,7 @@ __all__ = ['MobileNetV3', 'MobileNetV3Features'] def _cfg(url='', **kwargs): return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'conv_stem', 'classifier': 'classifier', @@ -88,7 +88,7 @@ default_cfgs = { test_input_size=(3, 256, 256), crop_pct=0.95), 'fbnetv3_g': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth', - input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95), + input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)), "lcnet_035": _cfg(), "lcnet_050": _cfg( @@ -134,6 +134,7 @@ class MobileNetV3(nn.Module): self.num_classes = num_classes self.num_features = num_features self.drop_rate = drop_rate + self.grad_checkpointing = False # Stem if not fix_stem: @@ -166,6 +167,18 @@ class MobileNetV3(nn.Module): layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^conv_stem|bn1', + blocks=r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)' + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + 
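
# Usage sketch for the new grad-checkpointing switch (assumes a patched model that supports it;
# 'mixer_b16_224' is just an example name): forward_features then routes the block stack
# through checkpoint_seq, trading compute for activation memory.
import torch
import timm

model = timm.create_model('mixer_b16_224', pretrained=False)
model.set_grad_checkpointing(True)
out = model(torch.randn(2, 3, 224, 224))
out.sum().backward()    # block activations are recomputed during this backward pass
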
@torch.jit.ignore def get_classifier(self): return self.classifier @@ -179,18 +192,28 @@ class MobileNetV3(nn.Module): def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x, flatten=True) + else: + x = self.blocks(x) + return x + + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) x = self.conv_head(x) x = self.act2(x) - return x + if pre_logits: + return x.flatten(1) + else: + x = self.flatten(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) def forward(self, x): x = self.forward_features(x) - x = self.flatten(x) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) - return self.classifier(x) + x = self.forward_head(x) + return x class MobileNetV3Features(nn.Module): diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 8000ed2e..1c55bd1c 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -20,6 +20,7 @@ from torch import nn import torch.nn.functional as F from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups +from .fx_features import register_notrace_module from .layers import to_2tuple, make_divisible from .vision_transformer import Block as TransformerBlock from .helpers import build_model_with_cfg @@ -139,6 +140,7 @@ model_cfgs = dict( ) +@register_notrace_module class MobileViTBlock(nn.Module): """ MobileViT block Paper: https://arxiv.org/abs/2110.02178?context=cs.LG @@ -206,7 +208,7 @@ class MobileViTBlock(nn.Module): # Unfold (feature map -> patches) patch_h, patch_w = self.patch_size B, C, H, W = x.shape - new_h, new_w = int(math.ceil(H / patch_h) * patch_h), int(math.ceil(W / patch_w) * patch_w) + new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w num_patches = num_patch_h * num_patch_w # N interpolate = False diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 571312d5..61de5f58 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -407,8 +407,9 @@ class ReductionCell1(nn.Module): class NASNetALarge(nn.Module): """NASNetALarge (6 @ 4032) """ - def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2, - num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'): + def __init__( + self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2, + num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'): super(NASNetALarge, self).__init__() self.num_classes = num_classes self.stem_size = stem_size @@ -503,6 +504,23 @@ class NASNetALarge(nn.Module): self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^conv0|cell_stem_[01]', + blocks=[ + (r'^cell_(\d+)', None), + (r'^reduction_cell_0', (6,)), + (r'^reduction_cell_1', (12,)), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.last_linear @@ -542,14 +560,18 @@ class NASNetALarge(nn.Module): x = self.act(x_cell_17) return x - def forward(self, x): - x = 
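
# Arithmetic sketch of the unfold-size rounding in MobileViTBlock above: spatial dims are
# rounded up to a multiple of the patch size, and the feature map is only interpolated when
# that changes the shape.
import math

patch_h = patch_w = 2
H, W = 31, 33
new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w   # 32, 34
num_patches = (new_h // patch_h) * (new_w // patch_w)                               # 16 * 17 = 272
needs_resize = (new_h, new_w) != (H, W)                                             # True here
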
self.forward_features(x) + def forward_head(self, x): x = self.global_pool(x) if self.drop_rate > 0: x = F.dropout(x, self.drop_rate, training=self.training) x = self.last_linear(x) return x + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + def _create_nasnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( diff --git a/timm/models/nest.py b/timm/models/nest.py index 6b9be873..655cd755 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -26,7 +26,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply +from .helpers import build_model_with_cfg, named_apply, checkpoint_seq from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ from .layers import _assert from .layers import create_conv2d, create_pool2d, to_ntuple @@ -179,6 +179,8 @@ class NestLevel(nn.Module): norm_layer=None, act_layer=None, pad_type=''): super().__init__() self.block_size = block_size + self.grad_checkpointing = False + self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim)) if prev_embed_dim is not None: @@ -204,7 +206,10 @@ class NestLevel(nn.Module): x = x.permute(0, 2, 3, 1) # (B, H', W', C), switch to channels last for transformer x = blockify(x, self.block_size) # (B, T, N, C') x = x + self.pos_embed - x = self.transformer_encoder(x) # (B, T, N, C') + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.transformer_encoder, x) + else: + x = self.transformer_encoder(x) # (B, T, N, C') x = deblockify(x, self.block_size) # (B, H', W', C') # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage return x.permute(0, 3, 1, 2) # (B, C, H', W') @@ -217,10 +222,12 @@ class Nest(nn.Module): - https://arxiv.org/abs/2105.12723 """ - def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512), - num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, - pad_type='', weight_init='', global_pool='avg'): + def __init__( + self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512), + num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, + pad_type='', weight_init='', global_pool='avg' + ): """ Args: img_size (int, tuple): input image size @@ -310,6 +317,7 @@ class Nest(nn.Module): self.init_weights(weight_init) + @torch.jit.ignore def init_weights(self, mode=''): assert mode in ('nlhb', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
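
# Quick check of the 'nlhb' (negative-log head bias) init above: with zero head inputs and a
# bias of -log(C), each class sigmoid starts near 1 / C, a common prior-matching init for
# BCE-style training (with plain softmax a constant bias has no effect).
import math

num_classes = 1000
head_bias = -math.log(num_classes)            # ~ -6.91
p0 = 1.0 / (1.0 + math.exp(-head_bias))       # sigmoid(head_bias) ~ 0.000999 ~ 1 / num_classes
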
@@ -321,6 +329,24 @@ class Nest(nn.Module): def no_weight_decay(self): return {f'level.{i}.pos_embed' for i in range(len(self.levels))} + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embed', # stem and embed + blocks=[ + (r'^levels.(\d+)' if coarse else r'^levels.(\d+).transformer_encoder.(\d+)', None), + (r'^levels.(\d+).(?:pool|pos_embed)', (0,)), + (r'^norm', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for l in self.levels: + l.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head @@ -330,22 +356,22 @@ class Nest(nn.Module): self.num_features, self.num_classes, pool_type=global_pool) def forward_features(self, x): - """ x shape (B, C, H, W) - """ x = self.patch_embed(x) x = self.levels(x) # Layer norm done over channel dim only (to NHWC and back) x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x - def forward(self, x): - """ x shape (B, C, H, W) - """ - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) - return self.head(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.): @@ -364,9 +390,6 @@ def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.) trunc_normal_(module.weight, std=.02, a=-2, b=2) if module.bias is not None: nn.init.zeros_(module.bias) - elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): - nn.init.zeros_(module.bias) - nn.init.ones_(module.weight) def resize_pos_embed(posemb, posemb_new): diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index dd15ff14..4b79da50 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -27,7 +27,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .fx_features import register_notrace_module -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .registry import register_model from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ get_act_layer, get_act_fn, get_attn, make_divisible @@ -84,23 +84,6 @@ default_cfgs = dict( nfnet_f7=_dcfg( url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), - nfnet_f0s=_dcfg( - url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), - nfnet_f1s=_dcfg( - url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)), - nfnet_f2s=_dcfg( - url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)), - nfnet_f3s=_dcfg( - url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)), - nfnet_f4s=_dcfg( - url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)), - nfnet_f5s=_dcfg( - url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)), - nfnet_f6s=_dcfg( - url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)), - nfnet_f7s=_dcfg( - url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), - nfnet_l0=_dcfg( 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), @@ -222,7 +205,7 @@ model_cfgs = dict( dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)), dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)), - # NFNet-F models w/ GELU (I will likely deprecate/remove these models and just keep dm_ ver for GELU) + # NFNet-F models w/ GELU nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)), nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)), nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)), @@ -232,16 +215,6 @@ model_cfgs = dict( nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)), nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)), - # NFNet-F models w/ SiLU (much faster in PyTorch) - nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'), - nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'), - nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'), - nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'), - nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'), - nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'), - nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), - nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), - # Experimental 'light' versions of NFNet-F that are little leaner nfnet_l0=_nfnet_cfg( depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, @@ -477,11 +450,15 @@ class NormFreeNet(nn.Module): * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput for what it is/does. Approx 8-10% throughput loss. """ - def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, - drop_rate=0., drop_path_rate=0.): + def __init__( + self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + drop_rate=0., drop_path_rate=0. + ): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate + self.grad_checkpointing = False + assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})." 
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d if cfg.gamma_in_act: @@ -568,6 +545,22 @@ class NormFreeNet(nn.Module): if m.bias is not None: nn.init.zeros_(m.bias) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=[ + (r'^stages.(\d+)' if coarse else r'^stages.(\d+).(\d+)', None), + (r'^final_conv', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -576,14 +569,20 @@ class NormFreeNet(nn.Module): def forward_features(self, x): x = self.stem(x) - x = self.stages(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) x = self.final_conv(x) x = self.final_act(x) return x + def forward_head(self, x): + return self.head(x) + def forward(self, x): x = self.forward_features(x) - x = self.head(x) + x = self.forward_head(x) return x @@ -732,78 +731,6 @@ def nfnet_f7(pretrained=False, **kwargs): return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) -@register_model -def nfnet_f0s(pretrained=False, **kwargs): - """ NFNet-F0 w/ SiLU - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ - return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs) - - -@register_model -def nfnet_f1s(pretrained=False, **kwargs): - """ NFNet-F1 w/ SiLU - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ - return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs) - - -@register_model -def nfnet_f2s(pretrained=False, **kwargs): - """ NFNet-F2 w/ SiLU - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ - return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs) - - -@register_model -def nfnet_f3s(pretrained=False, **kwargs): - """ NFNet-F3 w/ SiLU - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ - return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs) - - -@register_model -def nfnet_f4s(pretrained=False, **kwargs): - """ NFNet-F4 w/ SiLU - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ - return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs) - - -@register_model -def nfnet_f5s(pretrained=False, **kwargs): - """ NFNet-F5 w/ SiLU - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ - return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs) - - -@register_model -def nfnet_f6s(pretrained=False, **kwargs): - """ NFNet-F6 w/ SiLU - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ - return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs) - - -@register_model -def nfnet_f7s(pretrained=False, **kwargs): - """ NFNet-F7 w/ SiLU - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ - return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs) - - @register_model def nfnet_l0(pretrained=False, **kwargs): """ NFNet-L0b w/ SiLU diff --git a/timm/models/pit.py 
b/timm/models/pit.py index b0788c1e..bef625ab 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -148,9 +148,10 @@ class PoolingVisionTransformer(nn.Module): - https://arxiv.org/abs/2103.16302 """ def __init__(self, img_size, patch_size, stride, base_dims, depth, heads, - mlp_ratio, num_classes=1000, in_chans=3, distilled=False, + mlp_ratio, num_classes=1000, in_chans=3, distilled=False, global_pool='token', attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0): super(PoolingVisionTransformer, self).__init__() + assert global_pool in ('token',) padding = 0 img_size = to_2tuple(img_size) @@ -161,6 +162,7 @@ class PoolingVisionTransformer(nn.Module): self.base_dims = base_dims self.heads = heads self.num_classes = num_classes + self.global_pool = global_pool self.num_tokens = 2 if distilled else 1 self.patch_size = patch_size @@ -205,13 +207,17 @@ class PoolingVisionTransformer(nn.Module): def no_weight_decay(self): return {'pos_embed', 'cls_token'} + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + def get_classifier(self): if self.head_dist is not None: return self.head, self.head_dist else: return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if self.head_dist is not None: diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 4aef89f4..81067845 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -296,6 +296,15 @@ class PNASNet5Large(nn.Module): self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict(stem=r'^conv_0|cell_stem_[01]', blocks=r'^cell_(\d+)') + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.last_linear @@ -323,12 +332,15 @@ class PNASNet5Large(nn.Module): x = self.act(x_cell_11) return x - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate > 0: x = F.dropout(x, self.drop_rate, training=self.training) - x = self.last_linear(x) + return x if pre_logits else self.last_linear(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py new file mode 100644 index 00000000..3cf6b1a3 --- /dev/null +++ b/timm/models/poolformer.py @@ -0,0 +1,322 @@ +""" PoolFormer implementation + +Paper: `PoolFormer: MetaFormer is Actually What You Need for Vision` - https://arxiv.org/abs/2111.11418 + +Code adapted from official impl at https://github.com/sail-sg/poolformer, original copyright in comment below + +Modifications and additions for timm by / Copyright 2022, Ross Wightman +""" +# Copyright 2021 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import copy +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, checkpoint_seq +from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + poolformer_s12=_cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar', + crop_pct=0.9), + poolformer_s24=_cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar', + crop_pct=0.9), + poolformer_s36=_cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar', + crop_pct=0.9), + poolformer_m36=_cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar', + crop_pct=0.95), + poolformer_m48=_cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar', + crop_pct=0.95), +) + + +class PatchEmbed(nn.Module): + """ Patch Embedding that is implemented by a layer of conv. + Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H/stride, W/stride] + """ + + def __init__(self, in_chs=3, embed_dim=768, patch_size=16, stride=16, padding=0, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + self.proj = nn.Conv2d(in_chs, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + + +class GroupNorm1(nn.GroupNorm): + """ Group Normalization with 1 group. + Input: tensor in shape [B, C, H, W] + """ + + def __init__(self, num_channels, **kwargs): + super().__init__(1, num_channels, **kwargs) + + +class Pooling(nn.Module): + def __init__(self, pool_size=3): + super().__init__() + self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) + + def forward(self, x): + return self.pool(x) - x + + +class PoolFormerBlock(nn.Module): + """ + Args: + dim: embedding dim + pool_size: pooling size + mlp_ratio: mlp expansion ratio + act_layer: activation + norm_layer: normalization + drop: dropout rate + drop path: Stochastic Depth, refer to https://arxiv.org/abs/1603.09382 + use_layer_scale, --layer_scale_init_value: LayerScale, refer to https://arxiv.org/abs/2103.17239 + """ + + def __init__( + self, dim, pool_size=3, mlp_ratio=4., + act_layer=nn.GELU, norm_layer=GroupNorm1, + drop=0., drop_path=0., layer_scale_init_value=1e-5): + + super().__init__() + + self.norm1 = norm_layer(dim) + self.token_mixer = Pooling(pool_size=pool_size) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = ConvMlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + if layer_scale_init_value: + self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + else: + self.layer_scale_1 = None + self.layer_scale_2 = None + + def forward(self, x): + if self.layer_scale_1 is not None: + x = x + self.drop_path1(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.token_mixer(self.norm1(x))) + x = x + self.drop_path2(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path1(self.token_mixer(self.norm1(x))) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + return x + + +def basic_blocks( + dim, index, layers, + pool_size=3, mlp_ratio=4., + act_layer=nn.GELU, norm_layer=GroupNorm1, + drop_rate=.0, drop_path_rate=0., + layer_scale_init_value=1e-5, +): + """ generate PoolFormer blocks for a stage """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append(PoolFormerBlock( + dim, pool_size=pool_size, mlp_ratio=mlp_ratio, + act_layer=act_layer, norm_layer=norm_layer, + drop=drop_rate, drop_path=block_dpr, + layer_scale_init_value=layer_scale_init_value, + )) + blocks = nn.Sequential(*blocks) + return blocks + + +class PoolFormer(nn.Module): + """ PoolFormer + """ + + def __init__( + self, + layers, + embed_dims=(64, 128, 320, 512), + mlp_ratios=(4, 4, 4, 4), + downsamples=(True, True, True, True), + pool_size=3, + in_chans=3, + num_classes=1000, + global_pool='avg', + norm_layer=GroupNorm1, + act_layer=nn.GELU, + in_patch_size=7, + in_stride=4, + in_pad=2, + down_patch_size=3, + down_stride=2, + down_pad=1, + drop_rate=0., drop_path_rate=0., + layer_scale_init_value=1e-5, + **kwargs): + + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = embed_dims[-1] + self.grad_checkpointing = False + + self.patch_embed = PatchEmbed( + patch_size=in_patch_size, stride=in_stride, padding=in_pad, + in_chs=in_chans, embed_dim=embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(layers)): + network.append(basic_blocks( + embed_dims[i], i, layers, + pool_size=pool_size, mlp_ratio=mlp_ratios[i], + act_layer=act_layer, norm_layer=norm_layer, + drop_rate=drop_rate, drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value) + ) + if i < len(layers) - 1 and (downsamples[i] or embed_dims[i] != embed_dims[i + 1]): + # downsampling between stages + network.append(PatchEmbed( + in_chs=embed_dims[i], embed_dim=embed_dims[i + 1], + patch_size=down_patch_size, stride=down_stride, padding=down_pad) + ) + + self.network = nn.Sequential(*network) + self.norm = norm_layer(self.num_features) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + # init for classification + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^patch_embed', # stem and embed + blocks=[ + (r'^network\.(\d+)\.(\d+)', 
None), + (r'^network\.(\d+)', (0,)), + (r'^norm', (99999,)) + ], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + x = self.network(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean([-2, -1]) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_poolformer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg(PoolFormer, variant, pretrained, **kwargs) + return model + + +@register_model +def poolformer_s12(pretrained=False, **kwargs): + """ PoolFormer-S12 model, Params: 12M """ + model = _create_poolformer('poolformer_s12', pretrained=pretrained, layers=(2, 2, 6, 2), **kwargs) + return model + + +@register_model +def poolformer_s24(pretrained=False, **kwargs): + """ PoolFormer-S24 model, Params: 21M """ + model = _create_poolformer('poolformer_s24', pretrained=pretrained, layers=(4, 4, 12, 4), **kwargs) + return model + + +@register_model +def poolformer_s36(pretrained=False, **kwargs): + """ PoolFormer-S36 model, Params: 31M """ + model = _create_poolformer( + 'poolformer_s36', pretrained=pretrained, layers=(6, 6, 18, 6), layer_scale_init_value=1e-6, **kwargs) + return model + + +@register_model +def poolformer_m36(pretrained=False, **kwargs): + """ PoolFormer-M36 model, Params: 56M """ + layers = (6, 6, 18, 6) + embed_dims = (96, 192, 384, 768) + model = _create_poolformer( + 'poolformer_m36', pretrained=pretrained, layers=layers, embed_dims=embed_dims, + layer_scale_init_value=1e-6, **kwargs) + return model + + +@register_model +def poolformer_m48(pretrained=False, **kwargs): + """ PoolFormer-M48 model, Params: 73M """ + layers = (8, 8, 24, 8) + embed_dims = (96, 192, 384, 768) + model = _create_poolformer( + 'poolformer_m48', pretrained=pretrained, layers=layers, embed_dims=embed_dims, + layer_scale_init_value=1e-6, **kwargs) + return model diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 5497b74b..653ac6d5 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -19,10 +19,11 @@ from functools import partial from typing import Optional, Union, Callable import numpy as np +import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, named_apply +from .helpers import build_model_with_cfg, named_apply, checkpoint_seq from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct from .layers import get_act_layer, get_norm_act_layer, create_conv2d from .registry import register_model @@ -80,14 +81,13 @@ model_cfgs = dict( regnety_040s_gn=RegNetCfg( w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25, act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)), + # regnetv = 'preact regnet y' regnetv_040=RegNetCfg( depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, 
se_ratio=0.25, preact=True, act_layer='silu'), - # regnetw = 'preact regnet z' - regnetw_040=RegNetCfg( - depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, - downsample=None, preact=True, num_features=1536, act_layer='silu', - ), + regnetv_064=RegNetCfg( + depth=25, w0=112, wa=33.22, wm=2.27, group_size=72, se_ratio=0.25, preact=True, act_layer='silu', + downsample='avg'), # RegNet-Z (unverified) regnetz_005=RegNetCfg( @@ -95,6 +95,10 @@ model_cfgs = dict( downsample=None, linear_out=True, num_features=1024, act_layer='silu', ), regnetz_040=RegNetCfg( + depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, linear_out=True, num_features=0, act_layer='silu', + ), + regnetz_040h=RegNetCfg( depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, downsample=None, linear_out=True, num_features=1536, act_layer='silu', ), @@ -144,10 +148,11 @@ default_cfgs = dict( regnety_040s_gn=_cfg(url=''), regnetv_040=_cfg(url='', first_conv='stem'), - regnetw_040=_cfg(url='', first_conv='stem', input_size=(3, 256, 256), pool_size=(8, 8)), + regnetv_064=_cfg(url='', first_conv='stem'), regnetz_005=_cfg(url=''), regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + regnetz_040h=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), ) @@ -326,6 +331,8 @@ class RegStage(nn.Module): self, depth, in_chs, out_chs, stride, dilation, drop_path_rates=None, block_fn=Bottleneck, **block_kwargs): super(RegStage, self).__init__() + self.grad_checkpointing = False + first_dilation = 1 if dilation in (1, 2) else 2 for i in range(depth): block_stride = stride if i == 0 else 1 @@ -341,8 +348,11 @@ class RegStage(nn.Module): first_dilation = dilation def forward(self, x): - for block in self.children(): - x = block(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.children(), x) + else: + for block in self.children(): + x = block(x) return x @@ -375,6 +385,7 @@ class RegNet(nn.Module): curr_stride = 2 per_stage_args, common_args = self._get_stage_args( cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) + assert len(per_stage_args) == 4 block_fn = PreBottleneck if cfg.preact else Bottleneck for i, stage_args in enumerate(per_stage_args): stage_name = "s{}".format(i + 1) @@ -429,6 +440,19 @@ class RegNet(nn.Module): act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) return per_stage_args, common_args + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)', + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in list(self.children())[1:-1]: + s.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -436,13 +460,20 @@ class RegNet(nn.Module): self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) def forward_features(self, x): - for block in list(self.children())[:-1]: - x = block(x) + x = self.stem(x) + x = self.s1(x) + x = self.s2(x) + x = self.s3(x) + x = self.s4(x) + x = self.final_conv(x) return x + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + def forward(self, x): - for block in self.children(): - x = block(x) + x = self.forward_features(x) + x = self.forward_head(x) return x @@ -634,9 +665,9 @@ def regnetv_040(pretrained=False, **kwargs): @register_model 
-def regnetw_040(pretrained=False, **kwargs): +def regnetv_064(pretrained=False, **kwargs): """""" - return _create_regnet('regnetw_040', pretrained, **kwargs) + return _create_regnet('regnetv_064', pretrained, **kwargs) @register_model @@ -655,3 +686,12 @@ def regnetz_040(pretrained=False, **kwargs): but it's not clear it is equivalent to paper model as not detailed in the paper. """ return _create_regnet('regnetz_040', pretrained, zero_init_last=False, **kwargs) + + +@register_model +def regnetz_040h(pretrained=False, **kwargs): + """RegNetZ-4.0GF + NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py + but it's not clear it is equivalent to paper model as not detailed in the paper. + """ + return _create_regnet('regnetz_040h', pretrained, zero_init_last=False, **kwargs) diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 109fee1f..6c2fd1bf 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -50,9 +50,10 @@ class Bottle2neck(nn.Module): """ expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None, - act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_): + def __init__( + self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None, + act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_): super(Bottle2neck, self).__init__() self.scale = scale self.is_first = stride > 1 or downsample is not None @@ -87,7 +88,7 @@ class Bottle2neck(nn.Module): self.relu = act_layer(inplace=True) self.downsample = downsample - def zero_init_last_bn(self): + def zero_init_last(self): nn.init.zeros_(self.bn3.weight) def forward(self, x): @@ -110,8 +111,7 @@ class Bottle2neck(nn.Module): sp = self.relu(sp) spo.append(sp) if self.scale > 1: - if self.pool is not None: - # self.is_first == True, None check for torchscript + if self.pool is not None: # self.is_first == True, None check for torchscript spo.append(self.pool(spx[-1])) else: spo.append(spx[-1]) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 7bbe58e0..735b91a2 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -57,10 +57,11 @@ class ResNestBottleneck(nn.Module): # pylint: disable=unused-argument expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, - radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False, - reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + def __init__( + self, inplanes, planes, stride=1, downsample=None, + radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): super(ResNestBottleneck, self).__init__() assert reduce_first == 1 # not supported assert attn_layer is None # not supported @@ -102,7 +103,7 @@ class ResNestBottleneck(nn.Module): self.act3 = act_layer(inplace=True) self.downsample = downsample - def zero_init_last_bn(self): + def zero_init_last(self): nn.init.zeros_(self.bn3.weight) def forward(self, x): diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 8d3d9043..4428f2ca 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -15,7 +15,7 @@ import torch.nn as 
nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier from .registry import register_model @@ -105,7 +105,9 @@ default_cfgs = { first_conv='conv1.0'), 'resnext101_32x4d': _cfg(url=''), 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), - 'resnext101_64x4d': _cfg(url=''), + 'resnext101_64x4d': _cfg( + url='', + interpolation='bicubic', crop_pct=1.0, test_input_size=(3, 288, 288)), 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'), # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags @@ -345,7 +347,7 @@ class BasicBlock(nn.Module): self.dilation = dilation self.drop_path = drop_path - def zero_init_last_bn(self): + def zero_init_last(self): nn.init.zeros_(self.bn2.weight) def forward(self, x): @@ -411,7 +413,7 @@ class Bottleneck(nn.Module): self.dilation = dilation self.drop_path = drop_path - def zero_init_last_bn(self): + def zero_init_last(self): nn.init.zeros_(self.bn3.weight) def forward(self, x): @@ -600,12 +602,13 @@ class ResNet(nn.Module): cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0., - drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): + drop_block_rate=0., global_pool='avg', zero_init_last=True, block_args=None): + super(ResNet, self).__init__() block_args = block_args or dict() assert output_stride in (8, 16, 32) self.num_classes = num_classes self.drop_rate = drop_rate - super(ResNet, self).__init__() + self.grad_checkpointing = False # Stem deep_stem = 'deep' in stem_type @@ -632,7 +635,7 @@ class ResNet(nn.Module): if replace_stem_pool: self.maxpool = nn.Sequential(*filter(None, [ nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False), - create_aa(aa_layer, channels=inplanes, stride=2), + create_aa(aa_layer, channels=inplanes, stride=2) if aa_layer is not None else None, norm_layer(inplanes), act_layer(inplace=True) ])) @@ -662,22 +665,33 @@ class ResNet(nn.Module): self.num_features = 512 * block.expansion self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) - self.init_weights(zero_init_last_bn=zero_init_last_bn) + self.init_weights(zero_init_last=zero_init_last) - def init_weights(self, zero_init_last_bn=True): + @torch.jit.ignore + def init_weights(self, zero_init_last=True): for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) - if zero_init_last_bn: + if zero_init_last: for m in self.modules(): - if hasattr(m, 'zero_init_last_bn'): - m.zero_init_last_bn() + if hasattr(m, 'zero_init_last'): + m.zero_init_last() - def get_classifier(self): - return self.fc + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)') + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + 
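A minimal usage sketch for the checkpointing flag being added across these models (assumes only `timm.create_model` plus the new method from this patch): once enabled, forward_features() routes the residual stages through checkpoint_seq(), which keeps only segment-boundary activations and recomputes the rest during backward, trading extra forward compute for lower activation memory.

import timm
model = timm.create_model('resnet50', pretrained=False)
model.set_grad_checkpointing(True)   # flag is checked in forward_features() before running the stages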
self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self, name_only=False): + return 'fc' if name_only else self.fc def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes @@ -689,10 +703,13 @@ class ResNet(nn.Module): x = self.act1(x) x = self.maxpool(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True) + else: + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) return x def forward(self, x): diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index dbf3e9cc..b13d0960 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -36,10 +36,9 @@ import torch.nn as nn from functools import partial from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, named_apply, adapt_input_conv +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq from .registry import register_model -from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0,\ - EvoNorm2dS1, EvoNorm2dS2, FilterResponseNormTlu2d, FilterResponseNormAct2d,\ +from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, EvoNorm2dS1, FilterResponseNormTlu2d,\ ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d @@ -280,9 +279,10 @@ class DownsampleAvg(nn.Module): class ResNetStage(nn.Module): """ResNet Stage.""" - def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1, - avg_down=False, block_dpr=None, block_fn=PreActBottleneck, - act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs): + def __init__( + self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1, + avg_down=False, block_dpr=None, block_fn=PreActBottleneck, + act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs): super(ResNetStage, self).__init__() first_dilation = 1 if dilation in (1, 2) else 2 layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer) @@ -397,7 +397,9 @@ class ResNetV2(nn.Module): self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) self.init_weights(zero_init_last=zero_init_last) + self.grad_checkpointing = False + @torch.jit.ignore def init_weights(self, zero_init_last=True): named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @@ -405,6 +407,22 @@ class ResNetV2(nn.Module): def load_pretrained(self, checkpoint_path, prefix='resnet/'): _load_weights(self, checkpoint_path, prefix) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=r'^stages.(\d+)' if coarse else [ + (r'^stages.(\d+).blocks.(\d+)', None), + (r'^norm', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -415,13 +433,19 @@ class ResNetV2(nn.Module): def forward_features(self, x): x = self.stem(x) - x = self.stages(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x, flatten=True) + else: + x = self.stages(x) x = self.norm(x) return x + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, 
pre_logits=pre_logits) + def forward(self, x): x = self.forward_features(x) - x = self.head(x) + x = self.forward_head(x) return x diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index df1e0afe..902d344f 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -16,7 +16,7 @@ from functools import partial from math import ceil from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule from .registry import register_model from .efficientnet_builder import efficientnet_init_weights @@ -54,8 +54,9 @@ SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d) class LinearBottleneck(nn.Module): - def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, - act_layer='swish', dw_act_layer='relu6', drop_path=None): + def __init__( + self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, + act_layer='swish', dw_act_layer='relu6', drop_path=None): super(LinearBottleneck, self).__init__() self.use_shortcut = stride == 1 and in_chs <= out_chs self.in_channels = in_chs @@ -143,12 +144,15 @@ def _build_blocks( class ReXNetV1(nn.Module): - def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, - initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12., - ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.): + def __init__( + self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, + initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12., + ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0. + ): super(ReXNetV1, self).__init__() - self.drop_rate = drop_rate self.num_classes = num_classes + self.drop_rate = drop_rate + self.grad_checkpointing = False assert output_stride == 32 # FIXME support dilation stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32 @@ -165,6 +169,19 @@ class ReXNetV1(nn.Module): efficientnet_init_weights(self) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=r'^features.(\d+)', + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -173,12 +190,18 @@ class ReXNetV1(nn.Module): def forward_features(self, x): x = self.stem(x) - x = self.features(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.features, x, flatten=True) + else: + x = self.features(x) return x + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + def forward(self, x): x = self.forward_features(x) - x = self.head(x) + x = self.forward_head(x) return x diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 5bc1b78a..1a9ac929 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -174,6 +174,19 @@ class SelecSLS(nn.Module): nn.init.constant_(m.weight, 1.) nn.init.constant_(m.bias, 0.) 
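The group_matcher dicts added throughout this patch map regexes over parameter names to ordered groups (stem first, blocks by stage index, final norm/head last), which timm's grouping utilities can then use for things like layer-wise learning-rate decay or staged freezing. A rough sketch with a hypothetical helper (`group_of` is not a timm function) of how such a dict buckets parameter names:

import re

matcher = dict(stem=r'^stem', blocks=r'^features\.(\d+)')

def group_of(name):
    if re.match(matcher['stem'], name):
        return 0                                  # stem parameters -> earliest group
    m = re.match(matcher['blocks'], name)
    return 1 + int(m.group(1)) if m else 10**5    # unmatched names (e.g. head) -> last group

group_of('stem.0.weight')             # 0
group_of('features.3.conv1.weight')   # 4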
+ @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^features\.(\d+)', + blocks_head=r'^head' + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.fc @@ -187,12 +200,15 @@ class SelecSLS(nn.Module): x = self.head(self.from_seq(x)) return x - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.fc(x) + return x if pre_logits else self.fc(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x diff --git a/timm/models/senet.py b/timm/models/senet.py index d07f01ad..7a7a5e1c 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -14,6 +14,7 @@ support for extras like dilation, switchable BN/activations, feature extraction, import math from collections import OrderedDict +import torch import torch.nn as nn import torch.nn.functional as F @@ -120,8 +121,7 @@ class SEBottleneck(Bottleneck): """ expansion = 4 - def __init__(self, inplanes, planes, groups, reduction, stride=1, - downsample=None): + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): super(SEBottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes * 2) @@ -129,8 +129,7 @@ class SEBottleneck(Bottleneck): planes * 2, planes * 4, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) self.bn2 = nn.BatchNorm2d(planes * 4) - self.conv3 = nn.Conv2d( - planes * 4, planes * 4, kernel_size=1, bias=False) + self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.se_module = SEModule(planes * 4, reduction=reduction) @@ -146,14 +145,11 @@ class SEResNetBottleneck(Bottleneck): """ expansion = 4 - def __init__(self, inplanes, planes, groups, reduction, stride=1, - downsample=None): + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): super(SEResNetBottleneck, self).__init__() - self.conv1 = nn.Conv2d( - inplanes, planes, kernel_size=1, bias=False, stride=stride) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, stride=stride) self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) @@ -169,15 +165,12 @@ class SEResNeXtBottleneck(Bottleneck): """ expansion = 4 - def __init__(self, inplanes, planes, groups, reduction, stride=1, - downsample=None, base_width=4): + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None, base_width=4): super(SEResNeXtBottleneck, self).__init__() width = math.floor(planes * (base_width / 64)) * groups - self.conv1 = nn.Conv2d( - inplanes, width, kernel_size=1, bias=False, stride=1) + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1) self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d( - width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) + self.conv2 = 
nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) self.bn2 = nn.BatchNorm2d(width) self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) @@ -192,11 +185,9 @@ class SEResNetBlock(nn.Module): def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): super(SEResNetBlock, self).__init__() - self.conv1 = nn.Conv2d( - inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False) self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.se_module = SEModule(planes, reduction=reduction) @@ -225,9 +216,10 @@ class SEResNetBlock(nn.Module): class SENet(nn.Module): - def __init__(self, block, layers, groups, reduction, drop_rate=0.2, - in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1, - downsample_padding=0, num_classes=1000, global_pool='avg'): + def __init__( + self, block, layers, groups, reduction, drop_rate=0.2, + in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1, + downsample_padding=0, num_classes=1000, global_pool='avg'): """ Parameters ---------- @@ -366,6 +358,16 @@ class SENet(nn.Module): return nn.Sequential(*layers) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+).(\d+)') + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.last_linear @@ -383,16 +385,15 @@ class SENet(nn.Module): x = self.layer4(x) return x - def logits(self, x): + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.last_linear(x) - return x + return x if pre_logits else self.last_linear(x) def forward(self, x): x = self.forward_features(x) - x = self.logits(x) + x = self.forward_head(x) return x diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 0ca38b87..fb9f063a 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -46,9 +46,10 @@ default_cfgs = { class SelectiveKernelBasic(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, - sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + def __init__( + self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): super(SelectiveKernelBasic, self).__init__() sk_kwargs = sk_kwargs or {} @@ -69,7 +70,7 @@ class SelectiveKernelBasic(nn.Module): self.downsample = downsample self.drop_path = drop_path - def zero_init_last_bn(self): + def zero_init_last(self): nn.init.zeros_(self.conv2.bn.weight) def forward(self, x): @@ -90,10 +91,10 @@ class SelectiveKernelBasic(nn.Module): class 
SelectiveKernelBottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, - drop_block=None, drop_path=None): + def __init__( + self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, sk_kwargs=None, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): super(SelectiveKernelBottleneck, self).__init__() sk_kwargs = sk_kwargs or {} @@ -113,7 +114,7 @@ class SelectiveKernelBottleneck(nn.Module): self.downsample = downsample self.drop_path = drop_path - def zero_init_last_bn(self): + def zero_init_last(self): nn.init.zeros_(self.conv3.bn.weight) def forward(self, x): @@ -146,7 +147,7 @@ def skresnet18(pretrained=False, **kwargs): sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True) model_args = dict( block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs), - zero_init_last_bn=False, **kwargs) + zero_init_last=False, **kwargs) return _create_skresnet('skresnet18', pretrained, **model_args) @@ -160,7 +161,7 @@ def skresnet34(pretrained=False, **kwargs): sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True) model_args = dict( block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), - zero_init_last_bn=False, **kwargs) + zero_init_last=False, **kwargs) return _create_skresnet('skresnet34', pretrained, **model_args) @@ -174,7 +175,7 @@ def skresnet50(pretrained=False, **kwargs): sk_kwargs = dict(split_input=True) model_args = dict( block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), - zero_init_last_bn=False, **kwargs) + zero_init_last=False, **kwargs) return _create_skresnet('skresnet50', pretrained, **model_args) @@ -188,7 +189,7 @@ def skresnet50d(pretrained=False, **kwargs): sk_kwargs = dict(split_input=True) model_args = dict( block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last=False, **kwargs) return _create_skresnet('skresnet50d', pretrained, **model_args) @@ -200,6 +201,6 @@ def skresnext50_32x4d(pretrained=False, **kwargs): sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False) model_args = dict( block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, - block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last=False, **kwargs) return _create_skresnet('skresnext50_32x4d', pretrained, **model_args) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index cd571a0d..79d36b65 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -19,14 +19,13 @@ from typing import Optional import torch import torch.nn as nn -import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, named_apply -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from .helpers import build_model_with_cfg, named_apply, checkpoint_seq +from 
.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert from .registry import register_model -from .vision_transformer import checkpoint_filter_fn, _init_vit_weights +from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit _logger = logging.getLogger(__name__) @@ -85,6 +84,15 @@ default_cfgs = { url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', num_classes=21841), + 'swin_s3_tiny_224': _cfg( + url='https://github.com/silent-chen/AutoFormerV2-model-zoo/releases/download/v1.0.0/S3-T.pth' + ), + 'swin_s3_small_224': _cfg( + url='https://github.com/silent-chen/AutoFormerV2-model-zoo/releases/download/v1.0.0/S3-S.pth' + ), + 'swin_s3_base_224': _cfg( + url='https://github.com/silent-chen/AutoFormerV2-model-zoo/releases/download/v1.0.0/S3-B.pth' + ) } @@ -121,53 +129,64 @@ def window_reverse(windows, window_size: int, H: int, W: int): return x +def get_relative_position_index(win_h, win_w): + # get pair-wise relative position index for each token inside the window + coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += win_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += win_w - 1 + relative_coords[:, :, 0] *= 2 * win_w - 1 + return relative_coords.sum(-1) # Wh*Ww, Wh*Ww + + class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. + head_dim (int): Number of channels per head (dim // num_heads if not set) + window_size (tuple[int]): The height and width of the window. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 """ - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): + def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim - self.window_size = window_size # Wh, Ww + self.window_size = to_2tuple(window_size) # Wh, Ww + win_h, win_w = self.window_size + self.window_area = win_h * win_w self.num_heads = num_heads - head_dim = dim // num_heads + head_dim = head_dim or dim // num_heads + attn_dim = head_dim * num_heads self.scale = head_dim ** -0.5 - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH + self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w)) + + self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(attn_dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + def forward(self, x, mask: Optional[torch.Tensor] = None): """ Args: @@ -175,20 +194,16 @@ class WindowAttention(nn.Module): mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) + attn = attn + self._get_rel_pos_bias() if mask is not None: - 
nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + num_win = mask.shape[0] + attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: @@ -196,7 +211,7 @@ class WindowAttention(nn.Module): attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = (attn @ v).transpose(1, 2).reshape(B_, N, -1) x = self.proj(x) x = self.proj_drop(x) return x @@ -208,8 +223,9 @@ class SwinTransformerBlock(nn.Module): Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. window_size (int): Window size. + num_heads (int): Number of attention heads. + head_dim (int): Enforce the number of channels per head shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True @@ -220,13 +236,13 @@ class SwinTransformerBlock(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): + def __init__( + self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.input_resolution = input_resolution - self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio @@ -238,31 +254,29 @@ class SwinTransformerBlock(nn.Module): self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, - attn_drop=attn_drop, proj_drop=drop) + dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size), + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) cnt = 0 - for h in h_slices: - for w in w_slices: + for h in ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)): + for w in ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)): img_mask[:, h, w, :] = cnt cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = window_partition(img_mask, self.window_size) # num_win, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) @@ -287,11 +301,11 @@ class SwinTransformerBlock(nn.Module): shifted_x = x # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # num_win*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) @@ -320,12 +334,13 @@ class PatchMerging(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.out_dim = out_dim or 2 * dim self.norm = norm_layer(4 * dim) + self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False) def forward(self, x): """ @@ -350,15 +365,6 @@ class PatchMerging(nn.Module): return x - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. @@ -368,6 +374,7 @@ class BasicLayer(nn.Module): input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. + head_dim (int): Channels per head (dim // num_heads if not set) window_size (int): Local window size. 
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True @@ -376,47 +383,43 @@ class BasicLayer(nn.Module): drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + def __init__( + self, dim, out_dim, input_resolution, depth, num_heads=4, head_dim=None, + window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth - self.use_checkpoint = use_checkpoint + self.grad_checkpointing = False # build blocks - self.blocks = nn.ModuleList([ + self.blocks = nn.Sequential(*[ SwinTransformerBlock( - dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, + dim=dim, input_resolution=input_resolution, num_heads=num_heads, head_dim=head_dim, + window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): - for blk in self.blocks: - if not torch.jit.is_scripting() and self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) if self.downsample is not None: x = self.downsample(x) return x - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - class SwinTransformer(nn.Module): r""" Swin Transformer @@ -431,6 +434,7 @@ class SwinTransformer(nn.Module): embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. + head_dim (int, tuple(int)): window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True @@ -440,31 +444,26 @@ class SwinTransformer(nn.Module): norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False """ def __init__( self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg', - embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), head_dim=None, window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, weight_init='', **kwargs): + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, weight_init='', **kwargs): super().__init__() assert global_pool in ('', 'avg') self.num_classes = num_classes self.global_pool = global_pool self.num_layers = len(depths) self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) + norm_layer=norm_layer if patch_norm else None) num_patches = self.patch_embed.num_patches self.patch_grid = self.patch_embed.grid_size @@ -473,52 +472,80 @@ class SwinTransformer(nn.Module): self.pos_drop = nn.Dropout(p=drop_rate) # build layers + if not isinstance(embed_dim, (tuple, list)): + embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + embed_out_dim = embed_dim[1:] + [None] + head_dim = to_ntuple(self.num_layers)(head_dim) + window_size = to_ntuple(self.num_layers)(window_size) + mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule layers = [] - for i_layer in range(self.num_layers): + for i in range(self.num_layers): layers += [BasicLayer( - dim=int(embed_dim * 2 ** i_layer), - input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, + dim=embed_dim[i], + out_dim=embed_out_dim[i], + input_resolution=(self.patch_grid[0] // (2 ** i), self.patch_grid[1] // (2 ** i)), + depth=depths[i], + num_heads=num_heads[i], + head_dim=head_dim[i], + window_size=window_size[i], + mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint) - ] + downsample=PatchMerging if (i < self.num_layers - 1) else None + )] self.layers = nn.Sequential(*layers) self.norm = norm_layer(self.num_features) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - self.init_weights(weight_init) + if weight_init != 'skip': + self.init_weights(weight_init) + @torch.jit.ignore def init_weights(self, mode=''): - assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + assert mode in ('jax', 'jax_nlhb', 'moco', '') if self.absolute_pos_embed is not None: trunc_normal_(self.absolute_pos_embed, std=.02) head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
- named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl='jax' in mode), self) + named_apply(get_init_weights_vit(mode, head_bias=head_bias), self) @torch.jit.ignore def no_weight_decay(self): - return {'absolute_pos_embed'} + nwd = {'absolute_pos_embed'} + for n, _ in self.named_parameters(): + if 'relative_position_bias_table' in n: + nwd.add(n) + return nwd @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} + def group_matcher(self, coarse=False): + return dict( + stem=r'^absolute_pos_embed|patch_embed', # stem and embed + blocks=r'^layers.(\d+)' if coarse else [ + (r'^layers.(\d+).downsample', (0,)), + (r'^layers.(\d+).\w+.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for l in self.layers: + l.grad_checkpointing = enable + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - self.global_pool = global_pool + if global_pool is not None: + assert global_pool in ('', 'avg') + self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): @@ -530,11 +557,14 @@ class SwinTransformer(nn.Module): x = self.norm(x) # B L C return x - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): if self.global_pool == 'avg': x = x.mean(dim=1) - x = self.head(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x @@ -547,7 +577,6 @@ def _create_swin_transformer(variant, pretrained=False, **kwargs): return model - @register_model def swin_base_patch4_window12_384(pretrained=False, **kwargs): """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k @@ -636,3 +665,34 @@ def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs): model_kwargs = dict( patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_s3_tiny_224(pretrained=False, **kwargs): + """ Swin-S3-T @ 224x224, ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_s3_tiny_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_s3_small_224(pretrained=False, **kwargs): + """ Swin-S3-S @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_s3_small_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_s3_base_224(pretrained=False, **kwargs): + """ Swin-S3-B @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **model_kwargs) + diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 60879ccd..63107a27 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -9,6 +9,7 @@ 
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT import math import torch import torch.nn as nn +from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg @@ -77,7 +78,8 @@ class Attention(nn.Module): class Block(nn.Module): """ TNT Block """ - def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., + def __init__( + self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() # Inner transformer @@ -153,12 +155,16 @@ class PixelEmbed(nn.Module): class TNT(nn.Module): """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12, - num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', + embed_dim=768, in_dim=48, depth=12, num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): super().__init__() + assert global_pool in ('', 'token', 'avg') self.num_classes = num_classes + self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.grad_checkpointing = False self.pixel_embed = PixelEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride) @@ -206,11 +212,29 @@ class TNT(nn.Module): def no_weight_decay(self): return {'patch_pos', 'pixel_pos', 'cls_token'} + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj', # stem and embed / pos + blocks=[ + (r'^blocks.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'token', 'avg') self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): @@ -222,16 +246,24 @@ class TNT(nn.Module): patch_embed = patch_embed + self.patch_pos patch_embed = self.pos_drop(patch_embed) - for blk in self.blocks: - pixel_embed, patch_embed = blk(pixel_embed, patch_embed) + if self.grad_checkpointing and not torch.jit.is_scripting(): + for blk in self.blocks: + pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed) + else: + for blk in self.blocks: + pixel_embed, patch_embed = blk(pixel_embed, patch_embed) patch_embed = self.norm(patch_embed) return patch_embed + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = x[:, 0] - x = self.head(x) + x = self.forward_head(x) return x diff --git 
a/timm/models/tresnet.py b/timm/models/tresnet.py index 77c96aee..f5a1c99a 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -107,8 +107,9 @@ class BasicBlock(nn.Module): class Bottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, - act_layer="leaky_relu", aa_layer=None): + def __init__( + self, inplanes, planes, stride=1, downsample=None, use_se=True, + act_layer="leaky_relu", aa_layer=None): super(Bottleneck, self).__init__() self.conv1 = conv2d_iabn( inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, act_param=1e-3) @@ -130,7 +131,7 @@ class Bottleneck(nn.Module): self.conv3 = conv2d_iabn( planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") - self.relu = nn.ReLU(inplace=True) + self.act = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride @@ -144,10 +145,9 @@ class Bottleneck(nn.Module): out = self.conv2(out) if self.se is not None: out = self.se(out) - out = self.conv3(out) out = out + shortcut # no inplace - out = self.relu(out) + out = self.act(out) return out @@ -194,7 +194,7 @@ class TResNet(nn.Module): self.num_features = (self.planes * 8) * Bottleneck.expansion self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) - # model initilization + # model initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') @@ -231,6 +231,16 @@ class TResNet(nn.Module): block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer)) return nn.Sequential(*layers) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict(stem=r'^body.conv1', blocks=r'^body.layer(\d+)' if coarse else r'^body.layer(\d+).(\d+)') + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -241,9 +251,12 @@ class TResNet(nn.Module): def forward_features(self, x): return self.body(x) + def forward_head(self, x, pre_logits: bool = False): + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = self.head(x) + x = self.forward_head(x) return x diff --git a/timm/models/twins.py b/timm/models/twins.py index bb82e1fc..c6ca03ff 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -198,8 +198,9 @@ class GlobalSubSampleAttn(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None): + def __init__( + self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None): super().__init__() self.norm1 = norm_layer(dim) if ws is None: @@ -273,15 +274,17 @@ class Twins(nn.Module): Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git """ def __init__( - self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512), - num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None, - block_cls=Block): + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg', + embed_dims=(64, 128, 256, 512), num_heads=(1, 2, 
4, 8), mlp_ratios=(4, 4, 4, 4), depths=(3, 4, 6, 3), + sr_ratios=(8, 4, 2, 1), wss=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), block_cls=Block): super().__init__() self.num_classes = num_classes + self.global_pool = global_pool self.depths = depths self.embed_dims = embed_dims self.num_features = embed_dims[-1] + self.grad_checkpointing = False img_size = to_2tuple(img_size) prev_chs = in_chans @@ -319,11 +322,34 @@ class Twins(nn.Module): def no_weight_decay(self): return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()]) + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embeds.0', # stem and embed + blocks=[ + (r'^(?:blocks|patch_embeds|pos_block).(\d+)', None), + ('^norm', (99999,)) + ] if coarse else [ + (r'^blocks.(\d+).(\d+)', None), + (r'^(?:patch_embeds|pos_block).(\d+)', (0,)), + (r'^norm', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg') + self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def _init_weights(self, m): @@ -340,9 +366,6 @@ class Twins(nn.Module): m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() def forward_features(self, x): B = x.shape[0] @@ -359,10 +382,14 @@ class Twins(nn.Module): x = self.norm(x) return x + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=1) + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = x.mean(dim=1) - x = self.head(x) + x = self.forward_head(x) return x diff --git a/timm/models/vgg.py b/timm/models/vgg.py index 688ab575..f671de22 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .fx_features import register_notrace_module from .layers import ClassifierHead from .registry import register_model @@ -25,7 +25,7 @@ __all__ = [ def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'features.0', 'classifier': 'head.fc', @@ -56,8 +56,9 @@ cfgs: Dict[str, List[Union[str, int]]] = { @register_notrace_module # reason: FX can't symbolically trace control flow in forward method class ConvMlp(nn.Module): - def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, - drop_rate: float = 0.2, act_layer: nn.Module = None, conv_layer: nn.Module = None): + def __init__( + self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, + drop_rate: float = 0.2, act_layer: 
nn.Module = None, conv_layer: nn.Module = None): super(ConvMlp, self).__init__() self.input_kernel_size = kernel_size mid_features = int(out_features * mlp_ratio) @@ -83,23 +84,25 @@ class ConvMlp(nn.Module): class VGG(nn.Module): def __init__( - self, - cfg: List[Any], - num_classes: int = 1000, - in_chans: int = 3, - output_stride: int = 32, - mlp_ratio: float = 1.0, - act_layer: nn.Module = nn.ReLU, - conv_layer: nn.Module = nn.Conv2d, - norm_layer: nn.Module = None, - global_pool: str = 'avg', - drop_rate: float = 0., + self, + cfg: List[Any], + num_classes: int = 1000, + in_chans: int = 3, + output_stride: int = 32, + mlp_ratio: float = 1.0, + act_layer: nn.Module = nn.ReLU, + conv_layer: nn.Module = nn.Conv2d, + norm_layer: nn.Module = None, + global_pool: str = 'avg', + drop_rate: float = 0., ) -> None: super(VGG, self).__init__() assert output_stride == 32 self.num_classes = num_classes self.num_features = 4096 self.drop_rate = drop_rate + self.grad_checkpointing = False + self.use_norm = norm_layer is not None self.feature_info = [] prev_chs = in_chans net_stride = 1 @@ -121,6 +124,7 @@ class VGG(nn.Module): prev_chs = v self.features = nn.Sequential(*layers) self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}')) + self.pre_logits = ConvMlp( prev_chs, self.num_features, 7, mlp_ratio=mlp_ratio, drop_rate=drop_rate, act_layer=act_layer, conv_layer=conv_layer) @@ -129,6 +133,16 @@ class VGG(nn.Module): self._initialize_weights() + @torch.jit.ignore + def group_matcher(self, coarse=False): + # this treats BN layers as separate groups for bn variants, a lot of effort to fix that + return dict(stem=r'^features.0', blocks=r'^features.(\d+)') + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -139,12 +153,15 @@ class VGG(nn.Module): def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) - x = self.pre_logits(x) return x + def forward_head(self, x: torch.Tensor, pre_logits: bool = False): + x = self.pre_logits(x) + return x if pre_logits else self.head(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) - x = self.head(x) + x = self.forward_head(x) return x def _initialize_weights(self) -> None: diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 506db56e..112f888b 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -13,7 +13,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier from .registry import register_model @@ -41,8 +41,9 @@ default_cfgs = dict( class SpatialMlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, - act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): + def __init__( + self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -99,7 +100,7 @@ class Attention(nn.Module): def forward(self, x): B, C, H, W = x.shape x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 
2, 4, 3) - q, k, v = x[0], x[1], x[2] + q, k, v = x.unbind(0) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) @@ -113,9 +114,10 @@ class Attention(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d, - group=8, attn_disabled=False, spatial_conv=False): + def __init__( + self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d, + group=8, attn_disabled=False, spatial_conv=False): super().__init__() self.spatial_conv = spatial_conv self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -128,9 +130,8 @@ class Block(nn.Module): dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop) self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = SpatialMlp( - in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, group=group, spatial_conv=spatial_conv) # new setting def forward(self, x): @@ -141,10 +142,11 @@ class Block(nn.Module): class Visformer(nn.Module): - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, - depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', - vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, + depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', + vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None): super().__init__() img_size = to_2tuple(img_size) self.num_classes = num_classes @@ -160,8 +162,9 @@ class Visformer(nn.Module): self.stage_num1 = self.stage_num3 = depth // 3 self.stage_num2 = depth - self.stage_num1 - self.stage_num3 self.pos_embed = pos_embed - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self.grad_checkpointing = False + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stage 1 if self.vit_stem: self.stem = None @@ -194,7 +197,7 @@ class Visformer(nn.Module): else: self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size)) self.pos_drop = nn.Dropout(p=drop_rate) - self.stage1 = nn.ModuleList([ + self.stage1 = nn.Sequential(*[ Block( dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, @@ -211,7 +214,7 @@ class Visformer(nn.Module): img_size = [x // (patch_size // 8) for x in img_size] if self.pos_embed: self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) - self.stage2 = nn.ModuleList([ + self.stage2 = nn.Sequential(*[ Block( dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, @@ -228,7 +231,7 @@ class Visformer(nn.Module): img_size = [x // (patch_size // 8) for x in img_size] if self.pos_embed: self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size)) - self.stage3 = 
nn.ModuleList([ + self.stage3 = nn.Sequential(*[ Block( dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, @@ -255,12 +258,6 @@ class Visformer(nn.Module): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): if self.conv_init: nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') @@ -269,6 +266,22 @@ class Visformer(nn.Module): if m.bias is not None: nn.init.constant_(m.bias, 0.) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^patch_embed1|pos_embed1|stem', # stem and embed + blocks=[ + (r'^stage(\d+).(\d+)' if coarse else r'^stage(\d+).(\d+)', None), + (r'^(?:patch_embed|pos_embed)(\d+)', (0,)), + (r'^norm', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head @@ -283,36 +296,42 @@ class Visformer(nn.Module): # stage 1 x = self.patch_embed1(x) if self.pos_embed: - x = x + self.pos_embed1 - x = self.pos_drop(x) - for b in self.stage1: - x = b(x) + x = self.pos_drop(x + self.pos_embed1) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stage1, x) + else: + x = self.stage1(x) # stage 2 if not self.vit_stem: x = self.patch_embed2(x) if self.pos_embed: - x = x + self.pos_embed2 - x = self.pos_drop(x) - for b in self.stage2: - x = b(x) + x = self.pos_drop(x + self.pos_embed2) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stage2, x) + else: + x = self.stage2(x) # stage3 if not self.vit_stem: x = self.patch_embed3(x) if self.pos_embed: - x = x + self.pos_embed3 - x = self.pos_drop(x) - for b in self.stage3: - x = b(x) + x = self.pos_drop(x + self.pos_embed3) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stage3, x) + else: + x = self.stage3(x) x = self.norm(x) return x + def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = self.global_pool(x) - x = self.head(x) + x = self.forward_head(x) return x diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 6d89f2bf..1d6e79d8 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -27,9 +27,10 @@ from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F +import torch.utils.checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv +from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from .registry import register_model @@ -202,20 +203,23 @@ class Attention(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, 
norm_layer=nn.LayerNorm): + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): - x = x + self.drop_path(self.attn(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x + self.drop_path1(self.attn(self.norm1(x))) + x = x + self.drop_path2(self.mlp(self.norm2(x))) return x @@ -227,8 +231,8 @@ class VisionTransformer(nn.Module): """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, global_pool='', + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None): """ @@ -237,6 +241,7 @@ class VisionTransformer(nn.Module): patch_size (int, tuple): patch size in_chans (int): number of input channels num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads @@ -252,12 +257,15 @@ class VisionTransformer(nn.Module): act_layer: (nn.Module): MLP activation layer """ super().__init__() + assert global_pool in ('', 'avg', 'token') + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_tokens = 1 - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU + self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) @@ -301,17 +309,15 @@ class VisionTransformer(nn.Module): self.pre_logits = nn.Identity() def init_weights(self, mode=''): - assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
trunc_normal_(self.pos_embed, std=.02) - if 'jax' not in mode: - # init cls token to truncated normal if not following jax impl, jax impl is zero - trunc_normal_(self.cls_token, std=.02) - named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl='jax' in mode), self) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m): # this fn left here for compat with downstream users - _init_vit_weights(m) + init_weights_vit_timm(m) @torch.jit.ignore() def load_pretrained(self, checkpoint_path, prefix=''): @@ -321,12 +327,26 @@ class VisionTransformer(nn.Module): def no_weight_decay(self): return {'pos_embed', 'cls_token', 'dist_token'} + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool='', representation_size=None): + def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None): self.num_classes = num_classes - self.global_pool = global_pool + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool if representation_size is not None: self._reset_representation(representation_size) final_chs = self.representation_size if self.representation_size else self.embed_dim @@ -336,28 +356,36 @@ class VisionTransformer(nn.Module): x = self.patch_embed(x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) x = self.norm(x) return x - def forward(self, x): - x = self.forward_features(x) - if self.global_pool == 'avg': - x = x[:, self.num_tokens:].mean(dim=1) - else: - x = x[:, 0] + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) x = self.pre_logits(x) - x = self.head(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x -def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): - """ ViT weight initialization - * When called without n, head_bias, jax_impl args it will behave exactly the same - as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
- * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl - """ +def init_weights_vit_timm(module: nn.Module, name: str = ''): + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): + """ ViT weight initialization, matching JAX (Flax) impl """ if isinstance(module, nn.Linear): if name.startswith('head'): nn.init.zeros_(module.weight) @@ -366,25 +394,35 @@ def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., lecun_normal_(module.weight) nn.init.zeros_(module.bias) else: - if jax_impl: - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - if 'mlp' in name: - nn.init.normal_(module.bias, std=1e-6) - else: - nn.init.zeros_(module.bias) - else: - trunc_normal_(module.weight, std=.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif jax_impl and isinstance(module, nn.Conv2d): - # NOTE conv was left to pytorch default in my original init + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) - elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): - nn.init.zeros_(module.bias) - nn.init.ones_(module.weight) + + +def init_weights_vit_moco(module: nn.Module, name: str = ''): + """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ + if isinstance(module, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def get_init_weights_vit(mode='jax', head_bias: float = 0.): + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm @torch.no_grad() diff --git a/timm/models/volo.py b/timm/models/volo.py new file mode 100644 index 00000000..735453c8 --- /dev/null +++ b/timm/models/volo.py @@ -0,0 +1,750 @@ +""" Vision OutLOoker (VOLO) implementation + +Paper: `VOLO: Vision Outlooker for Visual Recognition` - https://arxiv.org/abs/2106.13112 + +Code adapted from official impl at https://github.com/sail-sg/volo, original copyright in comment below + +Modifications and additions for timm by / Copyright 2022, Ross Wightman +""" +# Copyright 2021 Sea Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
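For reference, a minimal sketch of how the weight-init dispatch introduced above (get_init_weights_vit applied via named_apply) is used; the toy module below is illustrative only and assumes timm at this revision:

import torch.nn as nn
from timm.models.helpers import named_apply
from timm.models.vision_transformer import get_init_weights_vit

# 'jax' selects the Flax-style scheme (with head_bias), 'moco' the moco-v3 qkv handling,
# any other mode falls back to the original trunc-normal timm init.
toy = nn.Sequential(nn.Linear(8, 8), nn.Conv2d(3, 8, kernel_size=3))
named_apply(get_init_weights_vit('jax', head_bias=0.), toy)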
+import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ +from timm.models.registry import register_model +from timm.models.helpers import build_model_with_cfg + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .96, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.conv.0', 'classifier': ('head', 'aux_head'), + **kwargs + } + + +default_cfgs = { + 'volo_d1_224': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar', + crop_pct=0.96), + 'volo_d1_384': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar', + crop_pct=1.0, input_size=(3, 384, 384)), + 'volo_d2_224': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar', + crop_pct=0.96), + 'volo_d2_384': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar', + crop_pct=1.0, input_size=(3, 384, 384)), + 'volo_d3_224': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar', + crop_pct=0.96), + 'volo_d3_448': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar', + crop_pct=1.0, input_size=(3, 448, 448)), + 'volo_d4_224': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar', + crop_pct=0.96), + 'volo_d4_448': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar', + crop_pct=1.15, input_size=(3, 448, 448)), + 'volo_d5_224': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar', + crop_pct=0.96), + 'volo_d5_448': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar', + crop_pct=1.15, input_size=(3, 448, 448)), + 'volo_d5_512': _cfg( + url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar', + crop_pct=1.15, input_size=(3, 512, 512)), +} + + +class OutlookAttention(nn.Module): + + def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + head_dim = dim // num_heads + self.num_heads = num_heads + self.kernel_size = kernel_size + self.padding = padding + self.stride = stride + self.scale = head_dim ** -0.5 + + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride) + self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True) + + def forward(self, x): + B, H, W, C = x.shape + + v = self.v(x).permute(0, 3, 1, 2) # B, C, H, W + + h, w = math.ceil(H / self.stride), math.ceil(W / self.stride) + v = self.unfold(v).reshape( + B, self.num_heads, C // self.num_heads, + self.kernel_size * self.kernel_size, h * w).permute(0, 1, 4, 3, 2) # B,H,N,kxk,C/H + + attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + attn = self.attn(attn).reshape( + B, h * w, self.num_heads, self.kernel_size * 
self.kernel_size, + self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4) # B,H,N,kxk,kxk + attn = attn * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size, h * w) + x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size, padding=self.padding, stride=self.stride) + + x = self.proj(x.permute(0, 2, 3, 1)) + x = self.proj_drop(x) + + return x + + +class Outlooker(nn.Module): + def __init__( + self, dim, kernel_size, padding, stride=1, num_heads=1, mlp_ratio=3., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, qkv_bias=False + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = OutlookAttention( + dim, num_heads, kernel_size=kernel_size, + padding=padding, stride=stride, + qkv_bias=qkv_bias, attn_drop=attn_drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Attention(nn.Module): + + def __init__( + self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, H, W, C = x.shape + + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, H, W, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Transformer(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, + attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class ClassAttention(nn.Module): + + def __init__( + self, dim, num_heads=8, head_dim=None, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + if head_dim is not None: + self.head_dim = head_dim + else: + head_dim = dim // num_heads + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + + self.kv = nn.Linear(dim, self.head_dim * self.num_heads * 2, bias=qkv_bias) + self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(self.head_dim * self.num_heads, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + + kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) + attn = ((q * self.scale) @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + cls_embed = (attn @ v).transpose(1, 2).reshape(B, 1, self.head_dim * self.num_heads) + cls_embed = self.proj(cls_embed) + cls_embed = self.proj_drop(cls_embed) + return cls_embed + + +class ClassBlock(nn.Module): + + def __init__( + self, dim, num_heads, head_dim=None, mlp_ratio=4., qkv_bias=False, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = ClassAttention( + dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + cls_embed = x[:, :1] + cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x))) + cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed))) + return torch.cat([cls_embed, x[:, 1:]], dim=1) + + +def get_block(block_type, **kargs): + if block_type == 'ca': + return ClassBlock(**kargs) + + +def rand_bbox(size, lam, scale=1): + """ + get bounding box as token labeling (https://github.com/zihangJiang/TokenLabeling) + return: bounding box + """ + W = size[1] // scale + H = size[2] // scale + cut_rat = np.sqrt(1. - lam) + cut_w = np.int(W * cut_rat) + cut_h = np.int(H * cut_rat) + + # uniform + cx = np.random.randint(W) + cy = np.random.randint(H) + + bbx1 = np.clip(cx - cut_w // 2, 0, W) + bby1 = np.clip(cy - cut_h // 2, 0, H) + bbx2 = np.clip(cx + cut_w // 2, 0, W) + bby2 = np.clip(cy + cut_h // 2, 0, H) + + return bbx1, bby1, bbx2, bby2 + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding. 
+ Different with ViT use 1 conv layer, we use 4 conv layers to do patch embedding + """ + + def __init__( + self, img_size=224, stem_conv=False, stem_stride=1, + patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384): + super().__init__() + assert patch_size in [4, 8, 16] + if stem_conv: + self.conv = nn.Sequential( + nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + ) + else: + self.conv = None + + self.proj = nn.Conv2d( + hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride) + self.num_patches = (img_size // patch_size) * (img_size // patch_size) + + def forward(self, x): + if self.conv is not None: + x = self.conv(x) + x = self.proj(x) # B, C, H, W + return x + + +class Downsample(nn.Module): + """ Image to Patch Embedding, downsampling between stage1 and stage2 + """ + + def __init__(self, in_embed_dim, out_embed_dim, patch_size=2): + super().__init__() + self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = x.permute(0, 3, 1, 2) + x = self.proj(x) # B, C, H, W + x = x.permute(0, 2, 3, 1) + return x + + +def outlooker_blocks( + block_fn, index, dim, layers, num_heads=1, kernel_size=3, padding=1, stride=2, + mlp_ratio=3., qkv_bias=False, attn_drop=0, drop_path_rate=0., **kwargs): + """ + generate outlooker layer in stage1 + return: outlooker layers + """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append( + block_fn( + dim, kernel_size=kernel_size, padding=padding, + stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, attn_drop=attn_drop, drop_path=block_dpr)) + blocks = nn.Sequential(*blocks) + return blocks + + +def transformer_blocks( + block_fn, index, dim, layers, num_heads, mlp_ratio=3., + qkv_bias=False, attn_drop=0, drop_path_rate=0., **kwargs): + """ + generate transformer layers in stage2 + return: transformer layers + """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append( + block_fn( + dim, num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + drop_path=block_dpr)) + blocks = nn.Sequential(*blocks) + return blocks + + +class VOLO(nn.Module): + """ + Vision Outlooker, the main class of our model + """ + + def __init__( + self, + layers, + img_size=224, + in_chans=3, + num_classes=1000, + global_pool='token', + patch_size=8, + stem_hidden_dim=64, + embed_dims=None, + num_heads=None, + downsamples=(True, False, False, False), + outlook_attention=(True, False, False, False), + mlp_ratio=3.0, + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + post_layers=('ca', 'ca'), + use_aux_head=True, + use_mix_token=False, + pooling_scale=2, + ): + super().__init__() + num_layers = len(layers) + mlp_ratio = to_ntuple(num_layers)(mlp_ratio) + img_size = to_2tuple(img_size) + + self.num_classes = num_classes + self.global_pool = global_pool + self.mix_token = 
use_mix_token + self.pooling_scale = pooling_scale + self.num_features = embed_dims[-1] + if use_mix_token: # enable token mixing, see token labeling for details. + self.beta = 1.0 + assert global_pool == 'token', "return all tokens if mix_token is enabled" + self.grad_checkpointing = False + + self.patch_embed = PatchEmbed( + stem_conv=True, stem_stride=2, patch_size=patch_size, + in_chans=in_chans, hidden_dim=stem_hidden_dim, + embed_dim=embed_dims[0]) + + # inital positional encoding, we add positional encoding after outlooker blocks + patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale) + self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1])) + self.pos_drop = nn.Dropout(p=drop_rate) + + # set the main block in network + network = [] + for i in range(len(layers)): + if outlook_attention[i]: + # stage 1 + stage = outlooker_blocks( + Outlooker, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i], + qkv_bias=qkv_bias, attn_drop=attn_drop_rate, norm_layer=norm_layer) + network.append(stage) + else: + # stage 2 + stage = transformer_blocks( + Transformer, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, + drop_path_rate=drop_path_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer) + network.append(stage) + + if downsamples[i]: + # downsampling between two stages + network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2)) + + self.network = nn.ModuleList(network) + + # set post block, for example, class attention layers + self.post_network = None + if post_layers is not None: + self.post_network = nn.ModuleList( + [ + get_block( + post_layers[i], + dim=embed_dims[-1], + num_heads=num_heads[-1], + mlp_ratio=mlp_ratio[-1], + qkv_bias=qkv_bias, + attn_drop=attn_drop_rate, + drop_path=0., + norm_layer=norm_layer) + for i in range(len(post_layers)) + ]) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) + trunc_normal_(self.cls_token, std=.02) + + # set output type + if use_aux_head: + self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + else: + self.aux_head = None + self.norm = norm_layer(self.num_features) + + # Classifier head + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[ + (r'^network\.(\d+)\.(\d+)', None), + (r'^network\.(\d+)', (0,)), + ], + blocks2=[ + (r'^cls_token', (0,)), + (r'^post_network\.(\d+)', None), + (r'^norm', (99999,)) + ], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + if self.aux_head is not None: + self.aux_head = nn.Linear(self.num_features, num_classes) if 
num_classes > 0 else nn.Identity() + + def forward_tokens(self, x): + for idx, block in enumerate(self.network): + if idx == 2: + # add positional encoding after outlooker blocks + x = x + self.pos_embed + x = self.pos_drop(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(block, x) + else: + x = block(x) + + B, H, W, C = x.shape + x = x.reshape(B, -1, C) + return x + + def forward_cls(self, x): + B, N, C = x.shape + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat([cls_tokens, x], dim=1) + for block in self.post_network: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(block, x) + else: + x = block(x) + return x + + def forward_train(self, x): + """ A separate forward fn for training with mix_token (if a train script supports). + Combining multiple modes in as single forward with different return types is torchscript hell. + """ + x = self.patch_embed(x) + x = x.permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C + + # mix token, see token labeling for details. + if self.mix_token and self.training: + lam = np.random.beta(self.beta, self.beta) + patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale + bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale) + temp_x = x.clone() + sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1 + sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2 + temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :] + x = temp_x + else: + bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0 + + # step2: tokens learning in the two stages + x = self.forward_tokens(x) + + # step3: post network, apply class attention or not + if self.post_network is not None: + x = self.forward_cls(x) + x = self.norm(x) + + if self.global_pool == 'avg': + x_cls = x.mean(dim=1) + elif self.global_pool == 'token': + x_cls = x[:, 0] + else: + x_cls = x + + if self.aux_head is None: + return x_cls + + x_aux = self.aux_head(x[:, 1:]) # generate classes in all feature tokens, see token labeling + if not self.training: + return x_cls + 0.5 * x_aux.max(1)[0] + + if self.mix_token and self.training: # reverse "mix token", see token labeling for details. + x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1]) + temp_x = x_aux.clone() + temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :] + x_aux = temp_x + x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1]) + + # return these: 1. class token, 2. classes from all feature tokens, 3. 
bounding box + return x_cls, x_aux, (bbx1, bby1, bbx2, bby2) + + def forward_features(self, x): + x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C + + # step2: tokens learning in the two stages + x = self.forward_tokens(x) + + # step3: post network, apply class attention or not + if self.post_network is not None: + x = self.forward_cls(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + out = x.mean(dim=1) + elif self.global_pool == 'token': + out = x[:, 0] + else: + out = x + if pre_logits: + return out + out = self.head(out) + if self.aux_head is not None: + # generate classes in all feature tokens, see token labeling + aux = self.aux_head(x[:, 1:]) + out = out + 0.5 * aux.max(1)[0] + return out + + def forward(self, x): + """ simplified forward (without mix token training) """ + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_volo(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + return build_model_with_cfg(VOLO, variant, pretrained, **kwargs) + + +@register_model +def volo_d1_224(pretrained=False, **kwargs): + """ VOLO-D1 model, Params: 27M """ + model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs) + model = _create_volo('volo_d1_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d1_384(pretrained=False, **kwargs): + """ VOLO-D1 model, Params: 27M """ + model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs) + model = _create_volo('volo_d1_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d2_224(pretrained=False, **kwargs): + """ VOLO-D2 model, Params: 59M """ + model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_volo('volo_d2_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d2_384(pretrained=False, **kwargs): + """ VOLO-D2 model, Params: 59M """ + model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_volo('volo_d2_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d3_224(pretrained=False, **kwargs): + """ VOLO-D3 model, Params: 86M """ + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_volo('volo_d3_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d3_448(pretrained=False, **kwargs): + """ VOLO-D3 model, Params: 86M """ + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_volo('volo_d3_448', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d4_224(pretrained=False, **kwargs): + """ VOLO-D4 model, Params: 193M """ + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs) + model = _create_volo('volo_d4_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d4_448(pretrained=False, **kwargs): + """ VOLO-D4 model, Params: 193M """ + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), 
**kwargs) + model = _create_volo('volo_d4_448', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d5_224(pretrained=False, **kwargs): + """ VOLO-D5 model, Params: 296M + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 + """ + model_args = dict( + layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), + mlp_ratio=4, stem_hidden_dim=128, **kwargs) + model = _create_volo('volo_d5_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d5_448(pretrained=False, **kwargs): + """ VOLO-D5 model, Params: 296M + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 + """ + model_args = dict( + layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), + mlp_ratio=4, stem_hidden_dim=128, **kwargs) + model = _create_volo('volo_d5_448', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d5_512(pretrained=False, **kwargs): + """ VOLO-D5 model, Params: 296M + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 + """ + model_args = dict( + layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), + mlp_ratio=4, stem_hidden_dim=128, **kwargs) + model = _create_volo('volo_d5_512', pretrained=pretrained, **model_args) + return model diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 507e4bb5..59ee470f 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath,\ create_attn, create_norm_act_layer, get_norm_act_layer @@ -178,8 +178,9 @@ class SequentialAppendList(nn.Sequential): class OsaBlock(nn.Module): - def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False, - depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None): + def __init__( + self, in_chs, mid_chs, out_chs, layer_per_block, residual=False, + depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None): super(OsaBlock, self).__init__() self.residual = residual @@ -207,10 +208,7 @@ class OsaBlock(nn.Module): next_in_chs = in_chs + layer_per_block * mid_chs self.conv_concat = ConvNormAct(next_in_chs, out_chs, **conv_kwargs) - if attn: - self.attn = create_attn(attn, out_chs) - else: - self.attn = None + self.attn = create_attn(attn, out_chs) if attn else None self.drop_path = drop_path @@ -231,10 +229,12 @@ class OsaBlock(nn.Module): class OsaStage(nn.Module): - def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True, - residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, - drop_path_rates=None): + def __init__( + self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True, + residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, + drop_path_rates=None): super(OsaStage, self).__init__() + self.grad_checkpointing = False if downsample: self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) @@ -258,14 +258,18 @@ class OsaStage(nn.Module): def forward(self, x): if self.pool is not None: x = self.pool(x) - x = self.blocks(x) + if 
self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) return x class VovNet(nn.Module): - def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, - output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.): + def __init__( + self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, + output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.): """ VovNet (v2) """ super(VovNet, self).__init__() @@ -315,12 +319,23 @@ class VovNet(nn.Module): for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1.) - nn.init.constant_(m.bias, 0.) elif isinstance(m, nn.Linear): nn.init.zeros_(m.bias) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)', + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -331,9 +346,13 @@ class VovNet(nn.Module): x = self.stem(x) return self.stages(x) + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + def forward(self, x): x = self.forward_features(x) - return self.head(x) + x = self.forward_head(x) + return x def _create_vovnet(variant, pretrained=False, **kwargs): diff --git a/timm/models/xception.py b/timm/models/xception.py index f9428d07..99d02c46 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -21,7 +21,7 @@ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 """ - +import torch.jit import torch.nn as nn import torch.nn.functional as F @@ -172,6 +172,21 @@ class Xception(nn.Module): m.weight.data.fill_(1) m.bias.data.zero_() + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^conv[12]|bn[12]', + blocks=[ + (r'^block(\d+)', None), + (r'^conv[34]|bn[34]', (99,)), + ], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, "gradient checkpointing not supported" + + @torch.jit.ignore def get_classifier(self): return self.fc @@ -210,12 +225,15 @@ class Xception(nn.Module): x = self.act4(x) return x - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate: F.dropout(x, self.drop_rate, training=self.training) - x = self.fc(x) + return x if pre_logits else self.fc(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index e1156674..e4f66bd3 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -7,11 +7,11 @@ Hacked together by / Copyright 2020 Ross Wightman """ from functools import partial +import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, checkpoint_seq from .layers import ClassifierHead, ConvNormAct, create_conv2d, 
get_norm_act_layer from .layers.helpers import to_3tuple from .registry import register_model @@ -39,6 +39,7 @@ default_cfgs = dict( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'), xception41p=_cfg(url=''), + xception65p=_cfg(url=''), ) @@ -167,12 +168,14 @@ class XceptionAligned(nn.Module): """Modified Aligned Xception """ - def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, preact=False, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'): + def __init__( + self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, preact=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'): super(XceptionAligned, self).__init__() + assert output_stride in (8, 16, 32) self.num_classes = num_classes self.drop_rate = drop_rate - assert output_stride in (8, 16, 32) + self.grad_checkpointing = False layer_args = dict(act_layer=act_layer, norm_layer=norm_layer) self.stem = nn.Sequential(*[ @@ -206,6 +209,18 @@ class XceptionAligned(nn.Module): self.head = ClassifierHead( in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^blocks.(\d+)', + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head.fc @@ -214,13 +229,19 @@ class XceptionAligned(nn.Module): def forward_features(self, x): x = self.stem(x) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) x = self.act(x) return x + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + def forward(self, x): x = self.forward_features(x) - x = self.head(x) + x = self.forward_head(x) return x @@ -307,3 +328,23 @@ def xception41p(pretrained=False, **kwargs): ] model_args = dict(block_cfg=block_cfg, preact=True, norm_layer=nn.BatchNorm2d, **kwargs) return _xception('xception41p', pretrained=pretrained, **model_args) + + +@register_model +def xception65p(pretrained=False, **kwargs): + """ Modified Aligned Xception-65 w/ Pre-Act + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True), + ] + model_args = dict( + block_cfg=block_cfg, preact=True, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) + return _xception('xception65p', pretrained=pretrained, **model_args) diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 91c99fc5..7782d721 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -16,6 +16,7 @@ from functools import partial import torch import torch.nn as nn +from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg @@ -215,8 +216,9 @@ class LPI(nn.Module): class ClassAttentionBlock(nn.Module): """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239""" - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 
drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False): + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False): super().__init__() self.norm1 = norm_layer(dim) @@ -292,8 +294,9 @@ class XCA(nn.Module): class XCABlock(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.): + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.): super().__init__() self.norm1 = norm_layer(dim) self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) @@ -325,9 +328,10 @@ class XCiT(nn.Module): https://github.com/facebookresearch/deit/ """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, + depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False): """ Args: img_size (int, tuple): input image size @@ -353,14 +357,17 @@ class XCiT(nn.Module): interaction (class LPI) and the patch embedding (class ConvPatchEmbed) """ super().__init__() + assert global_pool in ('', 'avg', 'token') img_size = to_2tuple(img_size) assert (img_size[0] % patch_size == 0) and (img_size[0] % patch_size == 0), \ '`patch_size` should divide image dimensions evenly' + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU + self.global_pool = global_pool + self.grad_checkpointing = False self.patch_embed = ConvPatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer) @@ -396,19 +403,32 @@ class XCiT(nn.Module): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token'} + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=r'^blocks.(\d+)', + cls_attn_blocks=[(r'^cls_attn_blocks.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): @@ -420,24 +440,33 @@ class 
XCiT(nn.Module): # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C) pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1) x = x + pos_encoding - x = self.pos_drop(x) for blk in self.blocks: - x = blk(x, Hp, Wp) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, Hp, Wp) + else: + x = blk(x, Hp, Wp) x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) for blk in self.cls_attn_blocks: - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) x = self.norm(x) return x + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = x[:, 0] - x = self.head(x) + x = self.forward_head(x) return x diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index e1749156..842d18f9 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -1,12 +1,16 @@ """ Optimizer Factory w/ Custom Weight Decay Hacked together by / Copyright 2021 Ross Wightman """ -from typing import Optional +import json +from itertools import islice +from typing import Optional, Callable, Tuple import torch import torch.nn as nn import torch.optim as optim +from timm.models.helpers import group_parameters + from .adabelief import AdaBelief from .adafactor import Adafactor from .adahessian import Adahessian @@ -28,21 +32,122 @@ except ImportError: has_apex = False -def add_weight_decay(model, weight_decay=1e-5, skip_list=()): +def param_groups_weight_decay( + model: nn.Module, + weight_decay=1e-5, + no_weight_decay_list=() +): + no_weight_decay_list = set(no_weight_decay_list) decay = [] no_decay = [] for name, param in model.named_parameters(): if not param.requires_grad: - continue # frozen weights - if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + continue + + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: no_decay.append(param) else: decay.append(param) + return [ {'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': weight_decay}] +def _group(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def _layer_map(model, layers_per_group=12, num_groups=None): + def _in_head(n, hp): + if not hp: + return True + elif isinstance(hp, (tuple, list)): + return any([n.startswith(hpi) for hpi in hp]) + else: + return n.startswith(hp) + + head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None) + names_trunk = [] + names_head = [] + for n, _ in model.named_parameters(): + names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n) + + # group non-head layers + num_trunk_layers = len(names_trunk) + if num_groups is not None: + layers_per_group = -(num_trunk_layers // -num_groups) + names_trunk = list(_group(names_trunk, layers_per_group)) + + num_trunk_groups = len(names_trunk) + layer_map = {n: i for i, l in enumerate(names_trunk) for n in l} + layer_map.update({n: num_trunk_groups for n in names_head}) + return layer_map + + +def param_groups_layer_decay( + model: nn.Module, + weight_decay: float = 0.05, + no_weight_decay_list: Tuple[str] = (), + layer_decay: float = .75, + end_layer_decay: Optional[float] = None, +): + """ + Parameter groups for layer-wise lr decay & weight decay + Based on BEiT:
https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + no_weight_decay_list = set(no_weight_decay_list) + param_group_names = {} # NOTE for debugging + param_groups = {} + + if hasattr(model, 'group_matcher'): + # FIXME interface needs more work + layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True) + else: + # fallback + layer_map = _layer_map(model) + num_layers = max(layer_map.values()) + 1 + layer_max = num_layers - 1 + layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers)) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if param.ndim == 1 or name in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0. + else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = layer_map.get(name, layer_max) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_groups: + this_scale = layer_scales[layer_id] + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "param_names": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["param_names"].append(name) + param_groups[group_name]["params"].append(param) + + # FIXME temporary output to debug new feature + print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) + + def optimizer_kwargs(cfg): """ cfg/argparse to kwargs helper Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. @@ -56,6 +161,8 @@ def optimizer_kwargs(cfg): kwargs['eps'] = cfg.opt_eps if getattr(cfg, 'opt_betas', None) is not None: kwargs['betas'] = cfg.opt_betas + if getattr(cfg, 'layer_decay', None) is not None: + kwargs['layer_decay'] = cfg.layer_decay if getattr(cfg, 'opt_args', None) is not None: kwargs.update(cfg.opt_args) return kwargs @@ -79,6 +186,8 @@ def create_optimizer_v2( weight_decay: float = 0., momentum: float = 0.9, filter_bias_and_bn: bool = True, + layer_decay: Optional[float] = None, + param_group_fn: Optional[Callable] = None, **kwargs): """ Create an optimizer. @@ -101,11 +210,21 @@ def create_optimizer_v2( """ if isinstance(model_or_params, nn.Module): # a model was passed in, extract parameters and add weight decays to appropriate layers - if weight_decay and filter_bias_and_bn: - skip = {} - if hasattr(model_or_params, 'no_weight_decay'): - skip = model_or_params.no_weight_decay() - parameters = add_weight_decay(model_or_params, weight_decay, skip) + no_weight_decay = {} + if hasattr(model_or_params, 'no_weight_decay'): + no_weight_decay = model_or_params.no_weight_decay() + + if param_group_fn: + parameters = param_group_fn(model_or_params) + elif layer_decay is not None: + parameters = param_groups_layer_decay( + model_or_params, + weight_decay=weight_decay, + layer_decay=layer_decay, + no_weight_decay_list=no_weight_decay) + weight_decay = 0. + elif weight_decay and filter_bias_and_bn: + parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay) weight_decay = 0. 
else: parameters = model_or_params.parameters() diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 21d51509..226d0e76 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -82,7 +82,10 @@ class Scheduler: if not isinstance(values, (list, tuple)): values = [values] * len(self.optimizer.param_groups) for param_group, value in zip(self.optimizer.param_groups, values): - param_group[self.param_group_field] = value + if 'lr_scale' in param_group: + param_group[self.param_group_field] = value * param_group['lr_scale'] + else: + param_group[self.param_group_field] = value def _add_noise(self, lrs, t): if self.noise_range_t is not None: diff --git a/train.py b/train.py index 60ff10d5..ea127251 100755 --- a/train.py +++ b/train.py @@ -112,9 +112,17 @@ parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', - help='input batch size for training (default: 128)') + help='Input batch size for training (default: 128)') parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', - help='validation batch size override (default: None)') + help='Validation batch size override (default: None)') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='torch.jit.script the full model') +parser.add_argument('--fuser', default='', type=str, + help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +parser.add_argument('--grad-checkpointing', action='store_true', default=False, + help='Enable gradient checkpointing through model blocks/stages') # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', @@ -131,7 +139,8 @@ parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') parser.add_argument('--clip-mode', type=str, default='norm', help='Gradient clipping mode. One of ("norm", "value", "agc")') - +parser.add_argument('--layer-decay', type=float, default=None, + help='layer-wise learning rate decay (default: None)') # Learning rate schedule parameters parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', @@ -188,7 +197,7 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original".
(default: None)'), -parser.add_argument('--aug-repeats', type=int, default=0, +parser.add_argument('--aug-repeats', type=float, default=0, help='Number of augmentation repetitions (distributed training only) (default: 0)') parser.add_argument('--aug-splits', type=int, default=0, help='Number of augmentation splits (default: 0, valid: 0 or >=2)') @@ -276,8 +285,6 @@ parser.add_argument('--native-amp', action='store_true', default=False, help='Use Native Torch AMP mixed precision') parser.add_argument('--no-ddp-bb', action='store_true', default=False, help='Force broadcast buffers for native DDP to off.') -parser.add_argument('--channels-last', action='store_true', default=False, - help='Use channels_last memory layout') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no-prefetcher', action='store_true', default=False, @@ -293,10 +300,6 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', parser.add_argument("--local_rank", default=0, type=int) parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, help='use the multi-epochs-loader to save time at the beginning of every epoch') -parser.add_argument('--torchscript', dest='torchscript', action='store_true', - help='convert model torchscript for inference') -parser.add_argument('--fuser', default='', type=str, - help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--log-wandb', action='store_true', default=False, help='log training and validation metrics to wandb') @@ -386,6 +389,9 @@ def main(): assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly + if args.grad_checkpointing: + model.set_grad_checkpointing(enable=True) + if args.local_rank == 0: _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') @@ -458,7 +464,7 @@ def main(): # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: - # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper + # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper model_ema = ModelEmaV2( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) if args.resume:
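
For reference, a minimal sketch of how the new layer-decay plumbing is intended to be used: `create_optimizer_v2(..., layer_decay=...)` builds one parameter group per layer (via `group_matcher()` when available, otherwise the `_layer_map` fallback), each group carrying an `lr_scale`, and the `Scheduler` change multiplies that scale into every scheduled learning rate. The model name and hyperparameter values below are arbitrary examples, not prescribed by the patch.

import timm
import torch
from timm.optim import create_optimizer_v2
from timm.scheduler import CosineLRScheduler

# any model exposing group_matcher() gets exact per-block grouping; others fall back to _layer_map
model = timm.create_model('xcit_nano_12_p16_224', pretrained=False)

# layer_decay builds per-layer param groups, each with lr_scale = layer_decay ** (num_layers - 1 - layer_id)
optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05, layer_decay=0.75)

scheduler = CosineLRScheduler(optimizer, t_initial=100)
scheduler.step(epoch=0)  # the scheduler now writes schedule_value * lr_scale into each group's lr

for group in optimizer.param_groups:
    print(group.get('lr_scale', 1.0), group['weight_decay'], len(group['params']))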
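
Similarly, a small sketch of the gradient checkpointing hooks added here, mirroring what train.py now does when `--grad-checkpointing` is set. Model choice, input size, and class count are arbitrary examples; models without checkpointing support assert in `set_grad_checkpointing` instead.

import torch
import timm

model = timm.create_model('xception65p', pretrained=False, num_classes=10)
model.set_grad_checkpointing(True)  # stages re-run their blocks through checkpoint_seq on backward

x = torch.randn(2, 3, 299, 299, requires_grad=True)
loss = model(x).sum()
loss.backward()  # activations inside checkpointed blocks are recomputed, trading compute for memory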