diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 53765fc8..fba6d1b8 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,6 +16,7 @@ from .regnet import * from .res2net import * from .resnest import * from .resnet import * +from .resnetv2 import * from .rexnet import * from .selecsls import * from .senet import * diff --git a/timm/models/factory.py b/timm/models/factory.py index 70209c96..a7b6c90e 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -6,8 +6,6 @@ from .layers import set_layer_config def create_model( model_name, pretrained=False, - num_classes=1000, - in_chans=3, checkpoint_path='', scriptable=None, exportable=None, @@ -18,8 +16,6 @@ def create_model( Args: model_name (str): name of model to instantiate pretrained (bool): load pretrained ImageNet-1k weights if true - num_classes (int): number of classes for final fully connected layer (default: 1000) - in_chans (int): number of input channels / colors (default: 3) checkpoint_path (str): path of checkpoint to load after model is initialized scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) @@ -30,7 +26,7 @@ def create_model( global_pool (str): global pool type (default: 'avg') **: other kwargs are model specific """ - model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) + model_args = dict(pretrained=pretrained) # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 77b98dc6..2a15e528 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -11,7 +11,7 @@ from typing import Callable import torch import torch.nn as nn -import torch.utils.model_zoo as model_zoo +from torch.hub import get_dir, load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .layers import Conv2dSame, Linear @@ -88,15 +88,70 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, raise FileNotFoundError() -def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True): +def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False): + r"""Loads a custom (read non .pth) weight file + + Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls + a passed in custom load fun, or the `load_pretrained` model member fn. + + If the object is already present in `model_dir`, it's deserialized and returned. + The default value of `model_dir` is ``/checkpoints`` where + `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + model: The instantiated model to load weights into + cfg (dict): Default pretrained model cfg + load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named + 'laod_pretrained' on the model will be called if it exists + progress (bool, optional): whether or not to display a progress bar to stderr. Default: False + check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention + ``filename-.ext`` where ```` is the first eight or more + digits of the SHA256 hash of the contents of the file. 
The hash is used to + ensure unique names and to verify the contents of the file. Default: False + """ if cfg is None: cfg = getattr(model, 'default_cfg') if cfg is None or 'url' not in cfg or not cfg['url']: - _logger.warning("Pretrained model URL is invalid, using random initialization.") + _logger.warning("Pretrained model URL does not exist, using random initialization.") return + url = cfg['url'] + + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + + if load_fn is not None: + load_fn(model, cached_file) + elif hasattr(model, 'load_pretrained'): + model.load_pretrained(cached_file) + else: + _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") + - state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') +def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): + if cfg is None: + cfg = getattr(model, 'default_cfg') + if cfg is None or 'url' not in cfg or not cfg['url']: + _logger.warning("Pretrained model URL does not exist, using random initialization.") + return + state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu') if filter_fn is not None: state_dict = filter_fn(state_dict) @@ -269,6 +324,7 @@ def build_model_with_cfg( feature_cfg: dict = None, pretrained_strict: bool = True, pretrained_filter_fn: Callable = None, + pretrained_custom_load: bool = False, **kwargs): pruned = kwargs.pop('pruned', False) features = False @@ -289,10 +345,13 @@ def build_model_with_cfg( # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) if pretrained: - load_pretrained( - model, - num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), - filter_fn=pretrained_filter_fn, strict=pretrained_strict) + if pretrained_custom_load: + load_custom_pretrained(model) + else: + load_pretrained( + model, + num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, strict=pretrained_strict) if features: feature_cls = FeatureListNet diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index dac1beb8..142377a9 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -7,7 +7,7 @@ from .classifier import ClassifierHead, create_classifier from .cond_conv2d import CondConv2d, get_condconv_initializer from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ set_layer_config -from .conv2d_same import Conv2dSame +from .conv2d_same import Conv2dSame, conv2d_same from .conv_bn_act import ConvBnAct from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_attn 
import create_attn @@ -20,8 +20,8 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d -from .norm_act import BatchNormAct2d -from .padding import get_padding +from .norm_act import BatchNormAct2d, GroupNormAct +from .padding import get_padding, get_same_padding, pad_same from .pool2d_same import AvgPool2dSame, create_pool2d from .se import SEModule from .selective_kernel import SelectiveKernelConv diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py index 89fe5458..516cc6c9 100644 --- a/timm/models/layers/classifier.py +++ b/timm/models/layers/classifier.py @@ -9,31 +9,43 @@ from .adaptive_avgmax_pool import SelectAdaptivePool2d from .linear import Linear -def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): - flatten = not use_conv # flatten when we use a Linear layer after pooling +def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): + flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling if not pool_type: assert num_classes == 0 or use_conv,\ 'Pooling can only be disabled if classifier is also removed or conv classifier is used' - flatten = False # disable flattening if pooling is pass-through (no pooling) - global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten) + flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) + global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) num_pooled_features = num_features * global_pool.feat_mult() + return global_pool, num_pooled_features + + +def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False): if num_classes <= 0: fc = nn.Identity() # pass-through (no classifier) elif use_conv: - fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True) + fc = nn.Conv2d(num_features, num_classes, 1, bias=True) else: # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue - fc = Linear(num_pooled_features, num_classes, bias=True) + fc = Linear(num_features, num_classes, bias=True) + return fc + + +def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): + global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) + fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) return global_pool, fc class ClassifierHead(nn.Module): """Classifier head w/ configurable global pooling and dropout.""" - def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.): + def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): super(ClassifierHead, self).__init__() self.drop_rate = drop_rate - self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type) + self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) + self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + self.flatten_after_fc = use_conv and pool_type def forward(self, x): x = self.global_pool(x) diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index bddf9b26..e3fe3940 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -68,8 +68,8 @@ class BatchNormAct2d(nn.BatchNorm2d): class GroupNormAct(nn.GroupNorm): - - def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, + # NOTE num_channel and 
num_groups order flipped for easier layer swaps / binding of fixed args + def __init__(self, num_channels, num_groups, eps=1e-5, affine=True, apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) if isinstance(act_layer, str): diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 18b3725f..60e1a276 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -403,7 +403,7 @@ class ReductionCell1(nn.Module): class NASNetALarge(nn.Module): """NASNetALarge (6 @ 4032) """ - def __init__(self, num_classes=1000, in_chans=1, stem_size=96, channel_multiplier=2, + def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2, num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'): super(NASNetALarge, self).__init__() self.num_classes = num_classes diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py new file mode 100644 index 00000000..6611ae49 --- /dev/null +++ b/timm/models/resnetv2.py @@ -0,0 +1,578 @@ +"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization. + +A PyTorch implementation of ResNetV2 adapted from the Google Big-Transfer (BiT) source code +at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have +been included here as pretrained models from their original .NPZ checkpoints. + +Additionally, supports non pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and +extra padding support to allow porting of official Hybrid ResNet pretrained weights from +https://github.com/google-research/vision_transformer + +Thanks to the Google team for the above two repositories and associated papers. + +Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020. +""" +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
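Aside on the `GroupNormAct` argument-order flip noted above: with `num_channels` first, the layer follows the `norm_layer(num_features)` calling convention of BatchNorm-style layers, so it can be bound once with `functools.partial` and dropped in wherever a norm constructor is expected, which is how the new `resnetv2.py` below uses it. A minimal sketch, not part of the patch, assuming this branch of `timm` is installed:

```python
from functools import partial

import torch
from timm.models.layers import GroupNormAct  # export added by this patch

# Bind the fixed arg once; the result is then called like nn.BatchNorm2d(num_features).
norm_layer = partial(GroupNormAct, num_groups=32)   # same binding resnetv2.py uses
norm = norm_layer(64)                               # GroupNorm(32 groups, 64 channels) + ReLU
out = norm(torch.randn(2, 64, 8, 8))                # shape is preserved: (2, 64, 8, 8)
print(out.shape)
```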
+ +from collections import OrderedDict # pylint: disable=g-importing-member + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .registry import register_model +from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 480, 480), 'pool_size': (7, 7), + 'crop_pct': 1.0, 'interpolation': 'bilinear', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + # pretrained on imagenet21k, finetuned on imagenet1k + 'resnetv2_50x1_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz'), + 'resnetv2_50x3_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz'), + 'resnetv2_101x1_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz'), + 'resnetv2_101x3_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz'), + 'resnetv2_152x2_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz'), + 'resnetv2_152x4_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz'), + + # trained on imagenet-21k + 'resnetv2_50x1_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz', + num_classes=21843), + 'resnetv2_50x3_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz', + num_classes=21843), + 'resnetv2_101x1_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz', + num_classes=21843), + 'resnetv2_101x3_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz', + num_classes=21843), + 'resnetv2_152x2_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz', + num_classes=21843), + 'resnetv2_152x4_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz', + num_classes=21843), + + + # trained on imagenet-1k + 'resnetv2_50x1_bits': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-S-R50x1-ILSVRC2012.npz'), + 'resnetv2_50x3_bits': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-S-R50x3-ILSVRC2012.npz'), + 'resnetv2_101x1_bits': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-S-R101x3-ILSVRC2012.npz'), + 'resnetv2_101x3_bits': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-S-R101x3-ILSVRC2012.npz'), + 'resnetv2_152x2_bits': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-S-R152x2-ILSVRC2012.npz'), + 'resnetv2_152x4_bits': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-S-R152x4-ILSVRC2012.npz'), +} + + +def make_div(v, divisor=8): + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class StdConv2d(nn.Conv2d): + + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5): + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=bias, groups=groups) + self.eps = eps + + def forward(self, 
x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / (torch.sqrt(v) + self.eps) + x = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class StdConv2dSame(nn.Conv2d): + """StdConv2d w/ TF compatible SAME padding. Used for ViT Hybrid model. + """ + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5): + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=bias, groups=groups) + self.eps = eps + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / (torch.sqrt(v) + self.eps) + x = conv2d_same(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +def tf2th(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(conv_weights) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + + Follows the implementation of "Identity Mappings in Deep Residual Networks": + https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua + + Except it puts the stride on 3x3 conv when available. + """ + + def __init__( + self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1, + act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.): + super().__init__() + first_dilation = first_dilation or dilation + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_div(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True, + conv_layer=conv_layer, norm_layer=norm_layer) + else: + self.downsample = None + + self.norm1 = norm_layer(in_chs) + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.norm2 = norm_layer(mid_chs) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + self.norm3 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + + def forward(self, x): + x_preact = self.norm1(x) + + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x_preact) + + # residual branch + x = self.conv1(x_preact) + x = self.conv2(self.norm2(x)) + x = self.conv3(self.norm3(x)) + x = self.drop_path(x) + return x + shortcut + + +class Bottleneck(nn.Module): + """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT. 
+ """ + def __init__( + self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1, + act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.): + super().__init__() + first_dilation = first_dilation or dilation + act_layer = act_layer or nn.ReLU + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_div(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, out_chs, stride=stride, dilation=dilation, preact=False, + conv_layer=conv_layer, norm_layer=norm_layer) + else: + self.downsample = None + + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.norm1 = norm_layer(mid_chs) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + self.norm2 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.norm3 = norm_layer(out_chs, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.act3 = act_layer(inplace=True) + + def forward(self, x): + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + # residual + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.conv3(x) + x = self.norm3(x) + x = self.act3(x + shortcut) + return x + + +class DownsampleConv(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True, + conv_layer=None, norm_layer=None): + super(DownsampleConv, self).__init__() + self.conv = conv_layer(in_chs, out_chs, 1, stride=stride) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(x)) + + +class DownsampleAvg(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, + preact=True, conv_layer=None, norm_layer=None): + """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" + super(DownsampleAvg, self).__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + else: + self.pool = nn.Identity() + self.conv = conv_layer(in_chs, out_chs, 1, stride=1) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(self.pool(x))) + + +class ResNetStage(nn.Module): + """ResNet Stage.""" + def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1, + avg_down=False, block_dpr=None, block_fn=PreActBottleneck, + act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs): + super(ResNetStage, self).__init__() + first_dilation = 1 if dilation in (1, 2) else 2 + layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer) + proj_layer = DownsampleAvg if avg_down else DownsampleConv + prev_chs = in_chs + self.blocks = nn.Sequential() + for block_idx in range(depth): + drop_path_rate = block_dpr[block_idx] if block_dpr else 0. 
+ stride = stride if block_idx == 0 else 1 + self.blocks.add_module(str(block_idx), block_fn( + prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups, + first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate, + **layer_kwargs, **block_kwargs)) + prev_chs = out_chs + first_dilation = dilation + proj_layer = None + + def forward(self, x): + x = self.blocks(x) + return x + + +def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None): + stem = OrderedDict() + assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same') + + # NOTE conv padding mode can be changed by overriding the conv_layer def + if 'deep' in stem_type: + # A 3 deep 3x3 conv stack as in ResNet V1D models + mid_chs = out_chs // 2 + stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2) + stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1) + stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1) + else: + # The usual 7x7 stem conv + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) + + if not preact: + stem['norm'] = norm_layer(out_chs) + + if 'fixed' in stem_type: + # 'fixed' SAME padding approximation that is used in BiT models + stem['pad'] = nn.ConstantPad2d(1, 0) + stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0) + elif 'same' in stem_type: + # full, input size based 'SAME' padding, used in ViT Hybrid model + stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same') + else: + # the usual PyTorch symmetric padding + stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + return nn.Sequential(stem) + + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode. 
+ """ + + def __init__(self, layers, channels=(256, 512, 1024, 2048), + num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, + act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), + drop_rate=0., drop_path_rate=0.): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + wf = width_factor + + self.feature_info = [] + stem_chs = make_div(stem_chs * wf) + self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) + if not preact: + self.feature_info.append(dict(num_chs=stem_chs, reduction=4, module='stem')) + + prev_chs = stem_chs + curr_stride = 4 + dilation = 1 + block_fn = PreActBottleneck if preact else Bottleneck + self.stages = nn.Sequential() + for stage_idx, (d, c) in enumerate(zip(layers, channels)): + out_chs = make_div(c * wf) + stride = 1 if stage_idx == 0 else 2 + if curr_stride >= output_stride: + dilation *= stride + stride = 1 + if preact: + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}.norm1')] + stage = ResNetStage( + prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down, + act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_fn=block_fn) + prev_chs = out_chs + curr_stride *= stride + if not preact: + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')] + self.stages.add_module(str(stage_idx), stage) + + self.num_features = prev_chs + self.norm = norm_layer(self.num_features) if preact else nn.Identity() + if preact: + self.feature_info += [dict(num_chs=self.num_features, reduction=curr_stride, module=f'norm')] + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) + + for n, m in self.named_modules(): + if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + if not self.head.global_pool.is_identity(): + x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled) + return x + + def load_pretrained(self, checkpoint_path, prefix='resnet/'): + import numpy as np + weights = np.load(checkpoint_path) + with torch.no_grad(): + self.stem.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) + self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma'])) + self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta'])) + self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel'])) + self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias'])) + for i, (sname, stage) in enumerate(self.stages.named_children()): + for j, (bname, block) in enumerate(stage.blocks.named_children()): + convname = 'standardized_conv2d' + block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/' + 
block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel'])) + block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel'])) + block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel'])) + block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma'])) + block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma'])) + block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma'])) + block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta'])) + block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta'])) + block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta'])) + if block.downsample is not None: + w = weights[f'{block_prefix}a/proj/{convname}/kernel'] + block.downsample.conv.weight.copy_(tf2th(w)) + + +def _create_resnetv2(variant, pretrained=False, **kwargs): + # FIXME feature map extraction is not setup properly for pre-activation mode right now + return build_model_with_cfg( + ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True, + feature_cfg=dict(flatten_sequential=True), **kwargs) + + +@register_model +def resnetv2_50x1_bitm(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50x1_bitm', pretrained=pretrained, + layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_50x3_bitm(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50x3_bitm', pretrained=pretrained, + layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_101x1_bitm(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101x1_bitm', pretrained=pretrained, + layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_101x3_bitm(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101x3_bitm', pretrained=pretrained, + layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_152x2_bitm(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152x2_bitm', pretrained=pretrained, + layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_152x4_bitm(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152x4_bitm', pretrained=pretrained, + layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50x1_bitm', pretrained=pretrained, + layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50x3_bitm', pretrained=pretrained, + layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101x1_bitm', pretrained=pretrained, + layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101x3_bitm', pretrained=pretrained, + layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152x2_bitm', pretrained=pretrained, + 
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152x4_bitm', pretrained=pretrained, + layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_50x1_bits(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50x1_bits', pretrained=pretrained, + layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_50x3_bits(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50x3_bits', pretrained=pretrained, + layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_101x1_bits(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101x1_bits', pretrained=pretrained, + layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_101x3_bits(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101x3_bits', pretrained=pretrained, + layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_152x2_bits(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152x2_bits', pretrained=pretrained, + layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) + + +@register_model +def resnetv2_152x4_bits(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152x4_bits', pretrained=pretrained, + layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) + diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 72f3a61a..9b96e04e 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -23,11 +23,13 @@ Hacked together by / Copyright 2020 Ross Wightman import torch import torch.nn as nn from functools import partial +from collections import OrderedDict from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import DropPath, to_2tuple, trunc_normal_ from .resnet import resnet26d, resnet50d +from .resnetv2 import ResNetV2, StdConv2dSame from .registry import register_model @@ -43,14 +45,19 @@ def _cfg(url='', **kwargs): default_cfgs = { - # patch models + # patch models (my experiments) 'vit_small_patch16_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', ), + + # patch models (weights ported from official JAX impl) 'vit_base_patch16_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), ), + 'vit_base_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 'vit_base_patch16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), @@ -60,15 +67,38 @@ default_cfgs = { 'vit_large_patch16_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 
'vit_large_patch16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 'vit_large_patch32_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), - 'vit_huge_patch16_224': _cfg(), - 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)), - # hybrid models + + # patch models, imagenet21k (weights ported from official JAX impl) + 'vit_base_patch16_224_in21k': _cfg( + url='', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_base_patch32_224_in21k': _cfg( + url='', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch16_224_in21k': _cfg( + url='', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch32_224_in21k': _cfg( + url='', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_huge_patch14_224_in21k': _cfg( + url='', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + + # hybrid models (weights ported from official JAX impl) + 'vit_base_resnet50_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + + # hybrid models (my experiments) 'vit_small_resnet26d_224': _cfg(), 'vit_small_resnet50d_s3_224': _cfg(), 'vit_base_resnet26d_224': _cfg(), @@ -184,20 +214,26 @@ class HybridEmbed(nn.Module): training = backbone.training if training: backbone.eval() - o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) - feature_dim = self.backbone.feature_info.channels()[-1] + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features self.num_patches = feature_size[0] * feature_size[1] - self.proj = nn.Linear(feature_dim, embed_dim) + self.proj = nn.Conv2d(feature_dim, embed_dim, 1) def forward(self, x): - x = self.backbone(x)[-1] - x = x.flatten(2).transpose(1, 2) - x = self.proj(x) + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x).flatten(2).transpose(1, 2) return x @@ -205,8 +241,8 @@ class VisionTransformer(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): super().__init__() self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models @@ -231,9 +267,14 @@ class VisionTransformer(nn.Module): for i in range(depth)]) self.norm = norm_layer(embed_dim) - # NOTE as per official 
impl, we could have a pre-logits representation dense layer + tanh here - #self.repr = nn.Linear(embed_dim, representation_size) - #self.repr_act = nn.Tanh() + # Representation layer + if representation_size: + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() # Classifier head self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() @@ -279,6 +320,7 @@ class VisionTransformer(nn.Module): def forward(self, x): x = self.forward_features(x) + x = self.pre_logits(x) x = self.head(x) return x @@ -318,6 +360,17 @@ def vit_base_patch16_224(pretrained=False, **kwargs): return model +@register_model +def vit_base_patch32_224(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_patch32_224'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + @register_model def vit_base_patch16_384(pretrained=False, **kwargs): model = VisionTransformer( @@ -351,6 +404,17 @@ def vit_large_patch16_224(pretrained=False, **kwargs): return model +@register_model +def vit_large_patch32_224(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=224, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_large_patch32_224'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + @register_model def vit_large_patch16_384(pretrained=False, **kwargs): model = VisionTransformer( @@ -374,17 +438,72 @@ def vit_large_patch32_384(pretrained=False, **kwargs): @register_model -def vit_huge_patch16_224(pretrained=False, **kwargs): - model = VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) - model.default_cfg = default_cfgs['vit_huge_patch16_224'] +def vit_base_patch16_224_in21k(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, num_classes=21843, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_patch16_224_in21k'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + return model + + +@register_model +def vit_base_patch32_224_in21k(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=224, num_classes=21843, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_patch32_224_in21k'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_large_patch16_224_in21k(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, num_classes=21843, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_large_patch16_224_in21k'] + if pretrained: + load_pretrained(model, 
num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model @register_model -def vit_huge_patch32_384(pretrained=False, **kwargs): +def vit_large_patch32_224_in21k(pretrained=False, **kwargs): model = VisionTransformer( - img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) - model.default_cfg = default_cfgs['vit_huge_patch32_384'] + img_size=224, num_classes=21843, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_large_patch32_224_in21k'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=224, patch_size=14, num_classes=21843, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_huge_patch14_224_in21k'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_base_resnet50_384(pretrained=False, **kwargs): + # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head + backbone = ResNetV2( + layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='') + model = VisionTransformer( + img_size=384, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_resnet50_384'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model diff --git a/train.py b/train.py index ca406655..98e4ddd4 100755 --- a/train.py +++ b/train.py @@ -76,8 +76,8 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Resume full model and optimizer state from checkpoint (default: none)') parser.add_argument('--no-resume-opt', action='store_true', default=False, help='prevent resume of optimizer state when resuming model') -parser.add_argument('--num-classes', type=int, default=1000, metavar='N', - help='number of label classes (default: 1000)') +parser.add_argument('--num-classes', type=int, default=None, metavar='N', + help='number of label classes (Model default if None)') parser.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') parser.add_argument('--img-size', type=int, default=None, metavar='N',
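Usage note on the `create_model` / `--num-classes` changes: because `num_classes` and `in_chans` are no longer injected with hard defaults by the factory, a model's own head size (e.g. the 21843-class `*_in21k` variants) is respected unless the caller overrides it, and `train.py` now falls back to the model default when `--num-classes` is omitted. A rough sketch of the resulting behaviour, assuming this branch is installed; `pretrained=False` is used to avoid any checkpoint download:

```python
import timm

# The in21k ViT variants hard-code num_classes=21843 in their model functions;
# create_model no longer overrides that with a default of 1000.
model = timm.create_model('vit_base_patch16_224_in21k', pretrained=False)
print(model.num_classes)  # 21843

# Explicit overrides are still forwarded through **kwargs as before.
model = timm.create_model('resnetv2_50x1_bitm', pretrained=False, num_classes=10)
print(model.num_classes)  # 10
```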