From a4d8fea61eef69fb42b7d8d39428d01f84d9ceb8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 13 Oct 2020 12:49:47 -0700 Subject: [PATCH 1/4] Add model based wd skip support. Improve cross version compat of optimizer factory. Fix #247 --- timm/optim/optim_factory.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index c53be368..80bac373 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -41,7 +41,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True): opt_lower = args.opt.lower() weight_decay = args.weight_decay if weight_decay and filter_bias_and_bn: - parameters = add_weight_decay(model, weight_decay) + skip = {} + if hasattr(model, 'no_weight_decay'): + skip = model.no_weight_decay + parameters = add_weight_decay(model, weight_decay, skip) weight_decay = 0. else: parameters = model.parameters() @@ -50,9 +53,9 @@ def create_optimizer(args, model, filter_bias_and_bn=True): assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' opt_args = dict(lr=args.lr, weight_decay=weight_decay) - if args.opt_eps is not None: + if hasattr(args, 'opt_eps') and args.opt_eps is not None: opt_args['eps'] = args.opt_eps - if args.opt_betas is not None: + if hasattr(args, 'opt_betas') and args.opt_betas is not None: opt_args['betas'] = args.opt_betas opt_split = opt_lower.split('_') From 9305313291ab1966b093abde83d78e1e7e15186d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 13 Oct 2020 12:58:04 -0700 Subject: [PATCH 2/4] Default to old checkpoint format for now, still want compatibility with older torch ver for released models --- avg_checkpoints.py | 6 +++++- clean_checkpoint.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/avg_checkpoints.py b/avg_checkpoints.py index feeac8af..a6921224 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -103,7 +103,11 @@ def main(): v = v.clamp(float32_info.min, float32_info.max) final_state_dict[k] = v.to(dtype=torch.float32) - torch.save(final_state_dict, args.output) + try: + torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False) + except: + torch.save(final_state_dict, args.output) + with open(args.output, 'rb') as f: sha_hash = hashlib.sha256(f.read()).hexdigest() print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) diff --git a/clean_checkpoint.py b/clean_checkpoint.py index af67f3b9..94f184d1 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -57,7 +57,11 @@ def main(): new_state_dict[name] = v print("=> Loaded state_dict from '{}'".format(args.checkpoint)) - torch.save(new_state_dict, _TEMP_NAME) + try: + torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False) + except: + torch.save(new_state_dict, _TEMP_NAME) + with open(_TEMP_NAME, 'rb') as f: sha_hash = hashlib.sha256(f.read()).hexdigest() From f31933cb374b39e8a6276d08b7b73ead24eea2d0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 13 Oct 2020 13:33:44 -0700 Subject: [PATCH 3/4] Initial Vision Transformer impl w/ patch and hybrid variants. Refactor tuple helpers. --- tests/test_models.py | 4 +- timm/models/__init__.py | 1 + timm/models/layers/__init__.py | 1 + timm/models/layers/cond_conv2d.py | 10 +- timm/models/layers/drop.py | 3 +- timm/models/layers/helpers.py | 10 +- timm/models/layers/median_pool.py | 8 +- timm/models/layers/pool2d_same.py | 12 +- timm/models/rexnet.py | 3 +- timm/models/vision_transformer.py | 377 ++++++++++++++++++++++++++++++ timm/models/xception_aligned.py | 6 +- 11 files changed, 408 insertions(+), 27 deletions(-) create mode 100644 timm/models/vision_transformer.py diff --git a/tests/test_models.py b/tests/test_models.py index d6fcaf79..fddddc31 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,9 +15,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models - EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d'] + EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', 'vit_*'] else: - EXCLUDE_FILTERS = [] + EXCLUDE_FILTERS = ['vit_*'] MAX_FWD_SIZE = 384 MAX_BWD_SIZE = 128 MAX_FWD_FEAT_SIZE = 448 diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 96719e5e..53765fc8 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -21,6 +21,7 @@ from .selecsls import * from .senet import * from .sknet import * from .tresnet import * +from .vision_transformer import * from .vovnet import * from .xception import * from .xception_aligned import * diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 4d5f8a69..a252b8c1 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -16,6 +16,7 @@ from .create_norm_act import create_norm_act, get_norm_act_layer from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple from .inplace_abn import InplaceAbn from .mixed_conv2d import MixedConv2d from .norm_act import BatchNormAct2d diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py index 175292b7..8b4bbca8 100644 --- a/timm/models/layers/cond_conv2d.py +++ b/timm/models/layers/cond_conv2d.py @@ -13,7 +13,7 @@ import torch from torch import nn as nn from torch.nn import functional as F -from .helpers import tup_pair +from .helpers import to_2tuple from .conv2d_same import conv2d_same from .padding import get_padding_value @@ -46,13 +46,13 @@ class CondConv2d(nn.Module): self.in_channels = in_channels self.out_channels = out_channels - self.kernel_size = tup_pair(kernel_size) - self.stride = tup_pair(stride) + self.kernel_size = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) padding_val, is_padding_dynamic = get_padding_value( padding, kernel_size, stride=stride, dilation=dilation) self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript - self.padding = tup_pair(padding_val) - self.dilation = tup_pair(dilation) + self.padding = to_2tuple(padding_val) + self.dilation = to_2tuple(dilation) self.groups = groups self.num_experts = num_experts diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index 625f1e70..6de9e3f7 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -150,7 +150,8 @@ def drop_path(x, drop_prob: float = 0., training: bool = False): if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob - random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device) + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py index d9aec8af..8d7b559b 100644 --- a/timm/models/layers/helpers.py +++ b/timm/models/layers/helpers.py @@ -15,11 +15,11 @@ def _ntuple(n): return parse -tup_single = _ntuple(1) -tup_pair = _ntuple(2) -tup_triple = _ntuple(3) -tup_quadruple = _ntuple(4) -ntup = _ntuple +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/timm/models/layers/median_pool.py b/timm/models/layers/median_pool.py index f900229f..40bd71a7 100644 --- a/timm/models/layers/median_pool.py +++ b/timm/models/layers/median_pool.py @@ -3,7 +3,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ import torch.nn as nn import torch.nn.functional as F -from .helpers import tup_pair, tup_quadruple +from .helpers import to_2tuple, to_4tuple class MedianPool2d(nn.Module): @@ -17,9 +17,9 @@ class MedianPool2d(nn.Module): """ def __init__(self, kernel_size=3, stride=1, padding=0, same=False): super(MedianPool2d, self).__init__() - self.k = tup_pair(kernel_size) - self.stride = tup_pair(stride) - self.padding = tup_quadruple(padding) # convert to l, r, t, b + self.k = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + self.padding = to_4tuple(padding) # convert to l, r, t, b self.same = same def _padding(self, x): diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py index 2e61b426..5fcd0f1f 100644 --- a/timm/models/layers/pool2d_same.py +++ b/timm/models/layers/pool2d_same.py @@ -7,7 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import List, Tuple, Optional -from .helpers import tup_pair +from .helpers import to_2tuple from .padding import pad_same, get_padding_value @@ -22,8 +22,8 @@ class AvgPool2dSame(nn.AvgPool2d): """ Tensorflow like 'SAME' wrapper for 2D average pooling """ def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): - kernel_size = tup_pair(kernel_size) - stride = tup_pair(stride) + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) def forward(self, x): @@ -42,9 +42,9 @@ class MaxPool2dSame(nn.MaxPool2d): """ Tensorflow like 'SAME' wrapper for 2D max pooling """ def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): - kernel_size = tup_pair(kernel_size) - stride = tup_pair(stride) - dilation = tup_pair(dilation) + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad) def forward(self, x): diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index a8161836..6444b3c8 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -17,6 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath from .registry import register_model +from .efficientnet_builder import efficientnet_init_weights def _cfg(url=''): @@ -186,7 +187,7 @@ class ReXNetV1(nn.Module): self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate) - # FIXME weight init, the original appears to use PyTorch defaults + efficientnet_init_weights(self) def get_classifier(self): return self.head.fc diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py new file mode 100644 index 00000000..b9857ed2 --- /dev/null +++ b/timm/models/vision_transformer.py @@ -0,0 +1,377 @@ +""" Vision Transformer (ViT) in PyTorch + +This is a WIP attempt to implement Vision Transformers as described in +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - +https://openreview.net/pdf?id=YicbFdNTTy + +The paper is currently under review and there is no official reference impl. The +code here is likely to change in the future and I will not make an effort to maintain +backwards weight compatibility when it does. + +Status/TODO: +* Trained (supervised on ImageNet-1k) my custom 'small' patch model to ~75 top-1 after 4 days, 2x GPU, +no dropout or stochastic depth active +* Need more time for supervised training results with dropout and drop connect active, hparam tuning +* Need more GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune +* There are likely mistakes. If you notice any, I'd love to improve this. This is my first time +fiddling with transformers/multi-head attn. +* Hopefully end up with worthwhile pretrained model at some point... + +Acknowledgments: +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropPath, to_2tuple, trunc_normal_ +from .resnet import resnet26d, resnet50d +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': '', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models + 'vit_small_patch16_224': _cfg(), + 'vit_base_patch16_224': _cfg(), + 'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)), + 'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)), + 'vit_large_patch16_224': _cfg(), + 'vit_large_patch16_384': _cfg(input_size=(3, 384, 384)), + 'vit_large_patch32_384': _cfg(input_size=(3, 384, 384)), + 'vit_huge_patch16_224': _cfg(), + 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)), + # hybrid models + 'vit_small_resnet26d_224': _cfg(), + 'vit_small_resnet50d_s3_224': _cfg(), + 'vit_base_resnet26d_224': _cfg(), + 'vit_base_resnet50d_224': _cfg(), +} + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.dropout = nn.Dropout(drop) # seems more common to have Transformer MLP drouput here? + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.): + super().__init__() + self.scale = 1. / dim ** 0.5 + self.num_heads = num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = qkv[:, :, 0].transpose(1, 2), qkv[:, :, 1].transpose(1, 2), qkv[:, :, 2].transpose(1, 2) + + # TODO benchmark vs above + #qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + #q, k, v = qkv + + attn = (q @ k.transpose(-2, -1)) * self.scale + # FIXME support masking + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., act_layer=nn.GELU, drop=0., drop_path=0.): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = Attention(dim, num_heads=num_heads, attn_drop=drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = nn.LayerNorm(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, attn_mask=None): + x = x + self.drop_path(self.attn(self.norm1(x), attn_mask=attn_mask)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Unfold image into fixed size patches, flatten into seq, project to embedding dim. + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, flatten_channels_last=False): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + assert img_size[0] % patch_size[0] == 0, 'image height must be divisible by the patch height' + assert img_size[1] % patch_size[1] == 0, 'image width must be divisible by the patch width' + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + patch_dim = in_chans * patch_size[0] * patch_size[1] + self.img_size = img_size + self.patch_size = patch_size + self.flatten_channels_last = flatten_channels_last + self.num_patches = num_patches + + self.proj = nn.Linear(patch_dim, embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + Ph, Pw = self.patch_size + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + if self.flatten_channels_last: + # flatten patches with channels last like the paper (likely using TF) + x = x.unfold(2, Ph, Ph).unfold(3, Pw, Pw).permute(0, 2, 3, 4, 5, 1).reshape(B, -1, Ph * Pw * C) + else: + x = x.permute(0, 2, 3, 1).unfold(1, Ph, Ph).unfold(2, Pw, Pw).reshape(B, -1, C * Ph * Pw) + x = self.proj(x) + return x + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., mlp_head=False, drop_rate=0., drop_path_rate=0., + flatten_channels_last=False, hybrid_backbone=None): + super().__init__() + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + flatten_channels_last=flatten_channels_last) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i]) + for i in range(depth)]) + + self.norm = nn.LayerNorm(embed_dim) + if mlp_head: + # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper + self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes) + else: + # with a single Linear layer as head, the param count within rounding of paper + self.head = nn.Linear(embed_dim, num_classes) + + # FIXME not quite sure what the proper weight init is supposed to be, + # normal / trunc normal w/ std == .02 similar to other Bert like transformers + trunc_normal_(self.pos_embed, std=.02) # embeddings same as weights? + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @property + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, attn_mask=None): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x += self.pos_embed + + for blk in self.blocks: + x = blk(x, attn_mask=attn_mask) + + x = self.norm(x[:, 0]) + x = self.head(x) + return x + + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['vit_small_patch16_224'] + return model + + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_base_patch16_224'] + return model + + +@register_model +def vit_base_patch16_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_base_patch16_384'] + return model + + +@register_model +def vit_base_patch32_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_base_patch32_384'] + return model + + +@register_model +def vit_large_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_large_patch16_224'] + return model + + +@register_model +def vit_large_patch16_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_large_patch16_384'] + return model + + +@register_model +def vit_large_patch32_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_large_patch32_384'] + return model + + +@register_model +def vit_huge_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_huge_patch16_224'] + return model + + +@register_model +def vit_huge_patch32_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_huge_patch32_384'] + return model + + +@register_model +def vit_small_resnet26d_224(pretrained=False, **kwargs): + pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing + backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) + model = VisionTransformer( + img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) + model.default_cfg = default_cfgs['vit_small_resnet26d_224'] + return model + + +@register_model +def vit_small_resnet50d_s3_224(pretrained=False, **kwargs): + pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing + backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[3]) + model = VisionTransformer( + img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) + model.default_cfg = default_cfgs['vit_small_resnet50d_s3_224'] + return model + + +@register_model +def vit_base_resnet26d_224(pretrained=False, **kwargs): + pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing + backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) + model = VisionTransformer( + img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) + model.default_cfg = default_cfgs['vit_base_resnet26d_224'] + return model + + +@register_model +def vit_base_resnet50d_224(pretrained=False, **kwargs): + pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing + backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) + model = VisionTransformer( + img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) + model.default_cfg = default_cfgs['vit_base_resnet50d_224'] + return model + + diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index f3a4a50a..e6b21576 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import build_model_with_cfg from .layers import ClassifierHead, ConvBnAct, create_conv2d -from .layers.helpers import tup_triple +from .layers.helpers import to_3tuple from .registry import register_model __all__ = ['XceptionAligned'] @@ -85,7 +85,7 @@ class XceptionModule(nn.Module): start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None): super(XceptionModule, self).__init__() norm_kwargs = norm_kwargs if norm_kwargs is not None else {} - out_chs = tup_triple(out_chs) + out_chs = to_3tuple(out_chs) self.in_channels = in_chs self.out_channels = out_chs[-1] self.no_skip = no_skip @@ -142,7 +142,7 @@ class XceptionAligned(nn.Module): b['dilation'] = curr_dilation if b['stride'] > 1: self.feature_info += [dict( - num_chs=tup_triple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')] + num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')] next_stride = curr_stride * b['stride'] if next_stride > output_stride: curr_dilation *= b['stride'] From be53107e8a3795c9b9ce8d3dab788a4f44602d80 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 13 Oct 2020 14:51:08 -0700 Subject: [PATCH 4/4] Update README, ensure vit excluded from all tests (not ready) --- README.md | 9 +++++++++ tests/test_models.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a56e47a6..c03a40e2 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,12 @@ ## What's New +### Oct 13, 2020 +* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train... +* Adafactor and AdaHessian (FP32 only, no AMP) optimizers +* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1 +* Pip release, doc updates pending a few more changes... + ### Sept 18, 2020 * New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D * Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D) @@ -124,6 +130,7 @@ A full version of the list below with source links can be found in the [document * SelecSLS - https://arxiv.org/abs/1907.00837 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586 * TResNet - https://arxiv.org/abs/2003.13630 +* Vision Transformer - https://openreview.net/forum?id=YicbFdNTTy * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667 * Xception - https://arxiv.org/abs/1610.02357 * Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611 @@ -162,6 +169,8 @@ Several (less common) features that I often utilize in my projects are included. * `lookahead` adapted from impl by [Liam](https://github.com/alphadl/lookahead.pytorch) (https://arxiv.org/abs/1907.08610) * `fused` optimizers by name with [NVIDIA Apex](https://github.com/NVIDIA/apex/tree/master/apex/optimizers) installed * `adamp` and `sgdp` by [Naver ClovAI](https://github.com/clovaai) (https://arxiv.org/abs/2006.08217) + * `adafactor` adapted from [FAIRSeq impl](https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py) (https://arxiv.org/abs/1804.04235) + * `adahessian` by [David Samuel](https://github.com/davda54/ada-hessian) (https://arxiv.org/abs/2006.00719) * Random Erasing from [Zhun Zhong](https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py) (https://arxiv.org/abs/1708.04896) * Mixup (https://arxiv.org/abs/1710.09412) * CutMix (https://arxiv.org/abs/1905.04899) diff --git a/tests/test_models.py b/tests/test_models.py index fddddc31..c673dc96 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -68,7 +68,7 @@ def test_model_backward(model_name, batch_size): @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models()) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=['vit_*'])) @pytest.mark.parametrize('batch_size', [1]) def test_model_default_cfgs(model_name, batch_size): """Run a single forward pass with each model"""