ResNetV2 pre-act and non-preact model, w/ BiT pretrained weights and support for ViT R50 model. Tweaks for in21k num_classes passing. More to do... tests failing.

pull/323/head
Ross Wightman 4 years ago
parent de6046e213
commit 231d04e91a

@ -16,6 +16,7 @@ from .regnet import *
from .res2net import *
from .resnest import *
from .resnet import *
from .resnetv2 import *
from .rexnet import *
from .selecsls import *
from .senet import *

@ -6,8 +6,6 @@ from .layers import set_layer_config
def create_model(
model_name,
pretrained=False,
num_classes=1000,
in_chans=3,
checkpoint_path='',
scriptable=None,
exportable=None,
@ -18,8 +16,6 @@ def create_model(
Args:
model_name (str): name of model to instantiate
pretrained (bool): load pretrained ImageNet-1k weights if true
num_classes (int): number of classes for final fully connected layer (default: 1000)
in_chans (int): number of input channels / colors (default: 3)
checkpoint_path (str): path of checkpoint to load after model is initialized
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
@ -30,7 +26,7 @@ def create_model(
global_pool (str): global pool type (default: 'avg')
**: other kwargs are model specific
"""
model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
model_args = dict(pretrained=pretrained)
# Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])

@ -11,7 +11,7 @@ from typing import Callable
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.hub import get_dir, load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame, Linear
@ -88,15 +88,70 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
raise FileNotFoundError()
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
    r"""Loads a custom (read non .pth) weight file

    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
    a passed in custom load fun, or the `load_pretrained` model member fn.

    If the file is already present in `model_dir`, the download is skipped and the cached copy is used.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        model: The instantiated model to load weights into
        cfg (dict): Default pretrained model cfg
        load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
            'load_pretrained' on the model will be called if it exists
        progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file. Default: False

    Returns:
        None. Weights are loaded into `model` in place; a warning is logged and the model is left
        untouched when no URL or no load function is available.
    """
    if cfg is None:
        # Default to None so a model without `default_cfg` reaches the warning below instead of raising
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL does not exist, using random initialization.")
        return
    url = cfg['url']

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    # Cache file is named after the last path component of the URL
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    # An explicit load_fn takes priority over the model's own `load_pretrained` method
    if load_fn is not None:
        load_fn(model, cached_file)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(cached_file)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
if cfg is None:
cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']:
_logger.warning("Pretrained model URL does not exist, using random initialization.")
return
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
if filter_fn is not None:
state_dict = filter_fn(state_dict)
@ -269,6 +324,7 @@ def build_model_with_cfg(
feature_cfg: dict = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Callable = None,
pretrained_custom_load: bool = False,
**kwargs):
pruned = kwargs.pop('pruned', False)
features = False
@ -289,6 +345,9 @@ def build_model_with_cfg(
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained_custom_load:
load_custom_pretrained(model)
else:
load_pretrained(
model,
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),

@ -7,7 +7,7 @@ from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config
from .conv2d_same import Conv2dSame
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import create_attn
@ -20,8 +20,8 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .norm_act import BatchNormAct2d
from .padding import get_padding
from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same
from .pool2d_same import AvgPool2dSame, create_pool2d
from .se import SEModule
from .selective_kernel import SelectiveKernelConv

@ -9,31 +9,43 @@ from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .linear import Linear
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
flatten = not use_conv # flatten when we use a Linear layer after pooling
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
    """Create the global pooling layer for a classifier head.

    Args:
        num_features: number of channels entering the pooling layer
        num_classes: number of classifier outputs, only used to validate the no-pooling case
        pool_type: pooling type passed to SelectAdaptivePool2d, a falsy value disables pooling
        use_conv: True if the classifier that follows is a conv instead of a Linear layer

    Returns:
        Tuple of (global_pool module, number of features after pooling).
    """
    flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
    if not pool_type:
        assert num_classes == 0 or use_conv,\
            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
        flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
    # feat_mult() scales the channel count for the chosen pooling type
    num_pooled_features = num_features * global_pool.feat_mult()
    return global_pool, num_pooled_features
def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False):
if num_classes <= 0:
fc = nn.Identity() # pass-through (no classifier)
elif use_conv:
fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
else:
# NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
fc = Linear(num_pooled_features, num_classes, bias=True)
fc = Linear(num_features, num_classes, bias=True)
return fc
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
    """Build the (global pooling, classifier) module pair for a model head.

    The classifier's input width is the pooled feature count reported by `_create_pool`.
    """
    pool, pooled_chs = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
    return pool, _create_fc(pooled_chs, num_classes, use_conv=use_conv)
class ClassifierHead(nn.Module):
"""Classifier head w/ configurable global pooling and dropout."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type)
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.flatten_after_fc = use_conv and pool_type
def forward(self, x):
x = self.global_pool(x)

@ -68,8 +68,8 @@ class BatchNormAct2d(nn.BatchNorm2d):
class GroupNormAct(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True,
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
def __init__(self, num_channels, num_groups, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
if isinstance(act_layer, str):

@ -403,7 +403,7 @@ class ReductionCell1(nn.Module):
class NASNetALarge(nn.Module):
"""NASNetALarge (6 @ 4032) """
def __init__(self, num_classes=1000, in_chans=1, stem_size=96, channel_multiplier=2,
def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
super(NASNetALarge, self).__init__()
self.num_classes = num_classes

@ -0,0 +1,578 @@
"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization.
A PyTorch implementation of ResNetV2 adapted from the Google Big-Transfer (BiT) source code
at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have
been included here as pretrained models from their original .NPZ checkpoints.
Additionally, supports non pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and
extra padding support to allow porting of official Hybrid ResNet pretrained weights from
https://github.com/google-research/vision_transformer
Thanks to the Google team for the above two repositories and associated papers.
Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
"""
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict # pylint: disable=g-importing-member
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same
def _cfg(url='', **kwargs):
    """Return a default pretrained cfg dict for a ResNetV2 variant.

    Any keyword argument overrides (or extends) the defaults below.
    """
    cfg = dict(
        url=url,
        num_classes=1000, input_size=(3, 480, 480), pool_size=(7, 7),
        crop_pct=1.0, interpolation='bilinear',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
        first_conv='stem.conv', classifier='head.fc',
    )
    cfg.update(kwargs)
    return cfg
# Default pretrained cfgs keyed by model variant name; all weights are the original BiT .npz checkpoints.
default_cfgs = {
    # pretrained on imagenet21k, finetuned on imagenet1k
    'resnetv2_50x1_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz'),
    'resnetv2_50x3_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz'),
    'resnetv2_101x1_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz'),
    'resnetv2_101x3_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz'),
    'resnetv2_152x2_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz'),
    'resnetv2_152x4_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz'),

    # trained on imagenet-21k
    'resnetv2_50x1_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
        num_classes=21843),
    'resnetv2_50x3_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
        num_classes=21843),
    'resnetv2_101x1_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
        num_classes=21843),
    'resnetv2_101x3_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
        num_classes=21843),
    'resnetv2_152x2_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
        num_classes=21843),
    'resnetv2_152x4_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
        num_classes=21843),

    # trained on imagenet-1k
    'resnetv2_50x1_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R50x1-ILSVRC2012.npz'),
    'resnetv2_50x3_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R50x3-ILSVRC2012.npz'),
    # NOTE(review): previously pointed at the R101x3 checkpoint, whose widths do not match a 101x1 model;
    # confirm the R101x1 file name against the BiT bucket listing
    'resnetv2_101x1_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R101x1-ILSVRC2012.npz'),
    'resnetv2_101x3_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R101x3-ILSVRC2012.npz'),
    'resnetv2_152x2_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R152x2-ILSVRC2012.npz'),
    'resnetv2_152x4_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R152x4-ILSVRC2012.npz'),
}
def make_div(v, divisor=8):
    """Round `v` to the nearest multiple of `divisor`, never going below `divisor` itself.

    If rounding lands more than 10% under `v`, one extra `divisor` step is added.
    """
    rounded = int(v + divisor / 2) // divisor * divisor
    rounded = max(divisor, rounded)
    # bump up one step when rounding shrank the value by more than 10%
    return rounded + divisor if rounded < 0.9 * v else rounded
class StdConv2d(nn.Conv2d):
    """Conv2d whose weight is standardized (per output channel) on every forward pass.

    Padding is derived from kernel size, stride and dilation via `get_padding`.
    """

    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=get_padding(kernel_size, stride, dilation), dilation=dilation, bias=bias, groups=groups)
        self.eps = eps  # added to the std-dev to avoid division by zero

    def forward(self, x):
        # statistics are taken over (in, kh, kw) for each output channel, biased variance
        var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        std_weight = (self.weight - mean) / (torch.sqrt(var) + self.eps)
        return F.conv2d(x, std_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class StdConv2dSame(nn.Conv2d):
    """Weight-standardized Conv2d routed through `conv2d_same`, for TF compatible SAME padding.

    Used for ViT Hybrid model.
    """

    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=get_padding(kernel_size, stride, dilation), dilation=dilation, bias=bias, groups=groups)
        self.eps = eps  # added to the std-dev to avoid division by zero

    def forward(self, x):
        # statistics are taken over (in, kh, kw) for each output channel, biased variance
        var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        std_weight = (self.weight - mean) / (torch.sqrt(var) + self.eps)
        # NOTE(review): how conv2d_same treats self.padding is defined in .layers, not visible here
        return conv2d_same(x, std_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def tf2th(conv_weights):
    """Convert a numpy weight array to a torch tensor, permuting 4-d (HWIO) kernels to OIHW."""
    is_conv_kernel = conv_weights.ndim == 4
    array = conv_weights.transpose([3, 2, 0, 1]) if is_conv_kernel else conv_weights
    return torch.from_numpy(array)
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on 3x3 conv when available.
    """

    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        # resolve defaults; `act_layer` is accepted for signature parity but not used by this block
        first_dilation = first_dilation or dilation
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        # projection shortcut, only present when the stage hands one in
        self.downsample = None if proj_layer is None else proj_layer(
            in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
            conv_layer=conv_layer, norm_layer=norm_layer)

        # norm -> conv ordering: each conv is preceded by its norm (+ activation)
        self.norm1 = norm_layer(in_chs)
        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm2 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm3 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.drop_path = nn.Identity() if not drop_path_rate > 0 else DropPath(drop_path_rate)

    def forward(self, x):
        preact = self.norm1(x)

        # projection shortcut consumes the pre-activated input, identity shortcut the raw input
        shortcut = x if self.downsample is None else self.downsample(preact)

        residual = self.conv1(preact)
        residual = self.conv2(self.norm2(residual))
        residual = self.conv3(self.norm3(residual))
        residual = self.drop_path(residual)
        return residual + shortcut
class Bottleneck(nn.Module):
    """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.

    Layout is conv -> norm for each of the three convs; the last norm has no activation, which is
    applied after the residual addition instead.
    """

    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        first_dilation = first_dilation or dilation
        act_layer = act_layer or nn.ReLU
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
                conv_layer=conv_layer, norm_layer=norm_layer)
        else:
            self.downsample = None

        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm1 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm2 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.norm3 = norm_layer(out_chs, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.act3 = act_layer(inplace=True)

    def forward(self, x):
        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        # residual
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        # stochastic depth on the residual branch only; previously drop_path was created but never applied
        x = self.drop_path(x)
        x = self.act3(x + shortcut)
        return x
class DownsampleConv(nn.Module):
    """Shortcut projection: a strided 1x1 conv, followed by a norm (no activation) unless pre-activation."""

    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
            conv_layer=None, norm_layer=None):
        super(DownsampleConv, self).__init__()
        # `dilation` / `first_dilation` are accepted for a uniform proj_layer signature but unused here
        self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
        if preact:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        projected = self.conv(x)
        return self.norm(projected)
class DownsampleAvg(nn.Module):
    """Shortcut projection that average-pools before an unstrided 1x1 conv."""

    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
            preact=True, conv_layer=None, norm_layer=None):
        """ AvgPool Downsampling as in 'D' ResNet variants."""
        super(DownsampleAvg, self).__init__()
        if stride <= 1 and dilation <= 1:
            # nothing to downsample
            self.pool = nn.Identity()
        else:
            # with dilation the pooling keeps stride 1
            pool_stride = 1 if dilation != 1 else stride
            use_same_pad = pool_stride == 1 and dilation > 1
            pool_cls = AvgPool2dSame if use_same_pad else nn.AvgPool2d
            self.pool = pool_cls(2, pool_stride, ceil_mode=True, count_include_pad=False)
        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
        if preact:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        pooled = self.pool(x)
        return self.norm(self.conv(pooled))
class ResNetStage(nn.Module):
    """A sequence of `depth` residual blocks; only the first block strides and projects its shortcut."""

    def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
                 avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
                 act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
        super(ResNetStage, self).__init__()
        layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
        first_proj = DownsampleAvg if avg_down else DownsampleConv
        first_block_dilation = 1 if dilation in (1, 2) else 2

        self.blocks = nn.Sequential()
        for block_idx in range(depth):
            is_first = block_idx == 0
            block = block_fn(
                in_chs if is_first else out_chs,
                out_chs,
                stride=stride if is_first else 1,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                first_dilation=first_block_dilation if is_first else dilation,
                proj_layer=first_proj if is_first else None,
                drop_path_rate=block_dpr[block_idx] if block_dpr else 0.,
                **layer_kwargs, **block_kwargs)
            self.blocks.add_module(str(block_idx), block)

    def forward(self, x):
        return self.blocks(x)
def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
    """Build the network stem as an nn.Sequential of named layers.

    Args:
        in_chs: input channels
        out_chs: stem output channels
        stem_type: one of '', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same'
        preact: when False a norm layer is inserted after the conv(s)
        conv_layer: conv constructor, called with (in, out, kernel_size=, stride=)
        norm_layer: norm constructor, only used when `preact` is False
    """
    assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
    named_layers = []

    # NOTE conv padding mode can be changed by overriding the conv_layer def
    if 'deep' not in stem_type:
        # The usual 7x7 stem conv
        named_layers.append(('conv', conv_layer(in_chs, out_chs, kernel_size=7, stride=2)))
    else:
        # A 3 deep 3x3 conv stack as in ResNet V1D models
        mid_chs = out_chs // 2
        named_layers.append(('conv1', conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)))
        named_layers.append(('conv2', conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)))
        named_layers.append(('conv3', conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)))

    if not preact:
        named_layers.append(('norm', norm_layer(out_chs)))

    if 'fixed' in stem_type:
        # 'fixed' SAME padding approximation that is used in BiT models
        named_layers.append(('pad', nn.ConstantPad2d(1, 0)))
        named_layers.append(('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)))
    elif 'same' in stem_type:
        # full, input size based 'SAME' padding, used in ViT Hybrid model
        named_layers.append(('pool', create_pool2d('max', kernel_size=3, stride=2, padding='same')))
    else:
        # the usual PyTorch symmetric padding
        named_layers.append(('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)))

    return nn.Sequential(OrderedDict(named_layers))
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode.

    With `preact=False` the non pre-activation `Bottleneck` block is used instead and the final
    norm is a pass-through.

    Args:
        layers: number of blocks in each stage
        channels: base output channels of each stage, scaled by `width_factor`
        num_classes: number of classifier outputs
        in_chans: number of input image channels
        global_pool: global pooling type for the classifier head
        output_stride: target overall stride; once reached, further stages dilate instead of striding
        width_factor: multiplier applied to stem and stage channel counts
        stem_chs: base stem output channels
        stem_type: stem variant, see `create_stem`
        avg_down: use avg-pool based shortcut downsampling instead of a strided conv
        preact: use pre-activation blocks
        act_layer, conv_layer, norm_layer: layer constructors handed down to stem / blocks
        drop_rate: classifier dropout rate
        drop_path_rate: maximum stochastic depth rate, increased linearly from 0 over all blocks
    """

    def __init__(self, layers, channels=(256, 512, 1024, 2048),
                 num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
                 width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
                 act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
                 drop_rate=0., drop_path_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        wf = width_factor

        self.feature_info = []
        stem_chs = make_div(stem_chs * wf)
        self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
        if not preact:
            self.feature_info.append(dict(num_chs=stem_chs, reduction=4, module='stem'))

        prev_chs = stem_chs
        curr_stride = 4
        dilation = 1
        # per-block stochastic depth rates, linearly spaced over the whole network and grouped by stage;
        # previously drop_path_rate was accepted but never reached the blocks
        layers = list(layers)
        block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
        block_fn = PreActBottleneck if preact else Bottleneck
        self.stages = nn.Sequential()
        for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
            out_chs = make_div(c * wf)
            stride = 1 if stage_idx == 0 else 2
            if curr_stride >= output_stride:
                # target stride reached, trade further striding for dilation
                dilation *= stride
                stride = 1
            if preact:
                self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}.norm1')]
            stage = ResNetStage(
                prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
                act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
            prev_chs = out_chs
            curr_stride *= stride
            if not preact:
                self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
            self.stages.add_module(str(stage_idx), stage)

        self.num_features = prev_chs
        self.norm = norm_layer(self.num_features) if preact else nn.Identity()
        if preact:
            self.feature_info += [dict(num_chs=self.num_features, reduction=curr_stride, module='norm')]
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

        # classifier layers get a small normal init, every other conv a kaiming-normal init
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def get_classifier(self):
        """Return the classifier layer of the head."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the head with a freshly initialized one for `num_classes` outputs."""
        self.num_classes = num_classes  # keep the attribute in sync with the new head
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

    def forward_features(self, x):
        """Run stem, stages and final norm; returns the unpooled feature map."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        if not self.head.global_pool.is_identity():
            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
        return x

    def load_pretrained(self, checkpoint_path, prefix='resnet/'):
        """Copy weights from a BiT style .npz checkpoint into this model.

        Args:
            checkpoint_path: path of the .npz file
            prefix: key prefix of every entry in the checkpoint

        NOTE(review): reads `self.norm.weight` and per-block `norm1..3`, so this presumably only fits
        the pre-activation layout -- confirm before using with `preact=False`.
        """
        import numpy as np
        weights = np.load(checkpoint_path)
        with torch.no_grad():
            self.stem.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
            self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
            self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
            self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
            self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
            for i, stage in enumerate(self.stages.children()):
                for j, block in enumerate(stage.blocks.children()):
                    convname = 'standardized_conv2d'
                    # checkpoint keys are 1-based: block<stage>/unit<NN>/
                    block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
                    block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
                    block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel']))
                    block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel']))
                    block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma']))
                    block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma']))
                    block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma']))
                    block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
                    block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta']))
                    block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta']))
                    if block.downsample is not None:
                        w = weights[f'{block_prefix}a/proj/{convname}/kernel']
                        block.downsample.conv.weight.copy_(tf2th(w))
def _create_resnetv2(variant, pretrained=False, **kwargs):
    """Instantiate a ResNetV2 `variant` through `build_model_with_cfg`, using the custom (.npz) weight loader."""
    # FIXME feature map extraction is not setup properly for pre-activation mode right now
    feature_cfg = dict(flatten_sequential=True)
    variant_cfg = default_cfgs[variant]
    return build_model_with_cfg(
        ResNetV2, variant, pretrained,
        default_cfg=variant_cfg, pretrained_custom_load=True, feature_cfg=feature_cfg, **kwargs)
@register_model
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
    """ResNetV2-50x1 BiT-M: pretrained on ImageNet-21k, fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_50x1_bitm', pretrained=pretrained, **model_args)
@register_model
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
    """ResNetV2-50x3 BiT-M: pretrained on ImageNet-21k, fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_50x3_bitm', pretrained=pretrained, **model_args)
@register_model
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
    """ResNetV2-101x1 BiT-M: pretrained on ImageNet-21k, fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_101x1_bitm', pretrained=pretrained, **model_args)
@register_model
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
    """ResNetV2-101x3 BiT-M: pretrained on ImageNet-21k, fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_101x3_bitm', pretrained=pretrained, **model_args)
@register_model
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
    """ResNetV2-152x2 BiT-M: pretrained on ImageNet-21k, fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_152x2_bitm', pretrained=pretrained, **model_args)
@register_model
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
    """ResNetV2-152x4 BiT-M: pretrained on ImageNet-21k, fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_152x4_bitm', pretrained=pretrained, **model_args)
@register_model
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
    """ResNetV2-50x1 BiT-M trained on ImageNet-21k (21843 classes unless `num_classes` is given)."""
    # default the head size to the in21k cfg so the pretrained classifier weights fit
    num_classes = kwargs.pop('num_classes', 21843)
    # use the in21k variant name so the matching cfg / checkpoint URL is selected
    return _create_resnetv2(
        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
@register_model
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
    """ResNetV2-50x3 BiT-M trained on ImageNet-21k (21843 classes unless `num_classes` is given)."""
    # default the head size to the in21k cfg so the pretrained classifier weights fit
    num_classes = kwargs.pop('num_classes', 21843)
    # use the in21k variant name so the matching cfg / checkpoint URL is selected
    return _create_resnetv2(
        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
@register_model
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
    """ResNetV2-101x1 BiT-M trained on ImageNet-21k (21843 classes unless `num_classes` is given)."""
    # default the head size to the in21k cfg so the pretrained classifier weights fit
    num_classes = kwargs.pop('num_classes', 21843)
    # use the in21k variant name so the matching cfg / checkpoint URL is selected
    return _create_resnetv2(
        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
@register_model
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
    """ResNetV2-101x3 BiT-M trained on ImageNet-21k (21843 classes unless `num_classes` is given)."""
    # default the head size to the in21k cfg so the pretrained classifier weights fit
    num_classes = kwargs.pop('num_classes', 21843)
    # use the in21k variant name so the matching cfg / checkpoint URL is selected
    return _create_resnetv2(
        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
@register_model
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
    """ResNetV2-152x2 BiT-M trained on ImageNet-21k (21843 classes unless `num_classes` is given)."""
    # default the head size to the in21k cfg so the pretrained classifier weights fit
    num_classes = kwargs.pop('num_classes', 21843)
    # use the in21k variant name so the matching cfg / checkpoint URL is selected
    return _create_resnetv2(
        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
@register_model
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
    """ResNetV2-152x4 BiT-M trained on ImageNet-21k (21843 classes unless `num_classes` is given)."""
    # default the head size to the in21k cfg so the pretrained classifier weights fit
    num_classes = kwargs.pop('num_classes', 21843)
    # use the in21k variant name so the matching cfg / checkpoint URL is selected
    return _create_resnetv2(
        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
@register_model
def resnetv2_50x1_bits(pretrained=False, **kwargs):
    """ResNetV2-50x1 BiT-S: trained on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_50x1_bits', pretrained=pretrained, **model_args)
@register_model
def resnetv2_50x3_bits(pretrained=False, **kwargs):
    """ResNetV2-50x3 BiT-S: trained on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_50x3_bits', pretrained=pretrained, **model_args)
@register_model
def resnetv2_101x1_bits(pretrained=False, **kwargs):
    """ResNetV2-101x1 BiT-S: trained on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_101x1_bits', pretrained=pretrained, **model_args)
@register_model
def resnetv2_101x3_bits(pretrained=False, **kwargs):
    """ResNetV2-101x3 BiT-S: trained on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_101x3_bits', pretrained=pretrained, **model_args)
@register_model
def resnetv2_152x2_bits(pretrained=False, **kwargs):
    """ResNetV2-152x2 BiT-S: trained on ImageNet-1k."""
    model_args = dict(layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_152x2_bits', pretrained=pretrained, **model_args)
@register_model
def resnetv2_152x4_bits(pretrained=False, **kwargs):
    """ResNetV2-152x4 BiT-S: trained on ImageNet-1k."""
    model_args = dict(layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
    return _create_resnetv2('resnetv2_152x4_bits', pretrained=pretrained, **model_args)

@ -23,11 +23,13 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch
import torch.nn as nn
from functools import partial
from collections import OrderedDict
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import DropPath, to_2tuple, trunc_normal_
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, StdConv2dSame
from .registry import register_model
@ -43,14 +45,19 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# patch models
# patch models (my experiments)
'vit_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
),
# patch models (weights ported from official JAX impl)
'vit_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
'vit_base_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_base_patch16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
@ -60,15 +67,38 @@ default_cfgs = {
'vit_large_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
'vit_large_patch32_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
'vit_huge_patch16_224': _cfg(),
'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)),
# hybrid models
# patch models, imagenet21k (weights ported from official JAX impl)
'vit_base_patch16_224_in21k': _cfg(
url='',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_base_patch32_224_in21k': _cfg(
url='',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch16_224_in21k': _cfg(
url='',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch32_224_in21k': _cfg(
url='',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_huge_patch14_224_in21k': _cfg(
url='',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# hybrid models (weights ported from official JAX impl)
'vit_base_resnet50_384': _cfg(
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
# hybrid models (my experiments)
'vit_small_resnet26d_224': _cfg(),
'vit_small_resnet50d_s3_224': _cfg(),
'vit_base_resnet26d_224': _cfg(),
@ -184,20 +214,26 @@ class HybridEmbed(nn.Module):
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)
self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x).flatten(2).transpose(1, 2)
return x
@ -205,8 +241,8 @@ class VisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
@ -231,9 +267,14 @@ class VisionTransformer(nn.Module):
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
#self.repr = nn.Linear(embed_dim, representation_size)
#self.repr_act = nn.Tanh()
# Representation layer
if representation_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
# Classifier head
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@ -279,6 +320,7 @@ class VisionTransformer(nn.Module):
def forward(self, x):
    """Run the full model: features, then pre-logits layer, then classifier head."""
    features = self.forward_features(x)
    representation = self.pre_logits(features)
    return self.head(representation)
@ -318,6 +360,17 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
return model
@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
    """Create the ``vit_base_patch32_224`` Vision Transformer.

    Builds a ``VisionTransformer`` with 224 image size, patch size 32, embed dim 768,
    depth 12 and 12 heads, then attaches the matching ``default_cfgs`` entry.
    When ``pretrained`` is true, ``load_pretrained`` is invoked on the new model.
    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit = VisionTransformer(
        img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=layer_norm, **kwargs)
    vit.default_cfg = default_cfgs['vit_base_patch32_224']
    if not pretrained:
        return vit
    load_pretrained(vit, num_classes=vit.num_classes, in_chans=kwargs.get('in_chans', 3))
    return vit
@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
model = VisionTransformer(
@ -351,6 +404,17 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
return model
@register_model
def vit_large_patch32_224(pretrained=False, **kwargs):
    """Create the ``vit_large_patch32_224`` Vision Transformer.

    Builds a ``VisionTransformer`` with 224 image size, patch size 32, embed dim 1024,
    depth 24 and 16 heads, then attaches the matching ``default_cfgs`` entry.
    When ``pretrained`` is true, ``load_pretrained`` is invoked on the new model.
    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit = VisionTransformer(
        img_size=224, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=layer_norm, **kwargs)
    vit.default_cfg = default_cfgs['vit_large_patch32_224']
    if not pretrained:
        return vit
    load_pretrained(vit, num_classes=vit.num_classes, in_chans=kwargs.get('in_chans', 3))
    return vit
@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
model = VisionTransformer(
@ -374,17 +438,72 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
@register_model
def vit_huge_patch16_224(pretrained=False, **kwargs):
model = VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
model.default_cfg = default_cfgs['vit_huge_patch16_224']
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
    """Create the ``vit_base_patch16_224_in21k`` Vision Transformer.

    Args:
        pretrained: when true, call ``load_pretrained`` (with ``_conv_filter`` as
            the state-dict filter) after the model is constructed.
        **kwargs: forwarded to ``VisionTransformer``. ``num_classes`` may be given
            by the caller; when it is absent or ``None`` the 21843-class default
            is used.

    Returns:
        The constructed ``VisionTransformer`` with ``default_cfg`` attached.
    """
    # Pull num_classes out of kwargs so a caller-supplied value overrides the
    # default instead of colliding with it as a duplicate keyword (TypeError).
    num_classes = kwargs.pop('num_classes', None)
    if num_classes is None:
        num_classes = 21843
    model = VisionTransformer(
        patch_size=16, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['vit_base_patch16_224_in21k']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
    return model
@register_model
def vit_huge_patch32_384(pretrained=False, **kwargs):
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
model.default_cfg = default_cfgs['vit_huge_patch32_384']
img_size=224, num_classes=21843, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_patch32_224_in21k']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
    """Create the ``vit_large_patch16_224_in21k`` Vision Transformer.

    Args:
        pretrained: when true, call ``load_pretrained`` after the model is constructed.
        **kwargs: forwarded to ``VisionTransformer``. ``num_classes`` may be given
            by the caller; when it is absent or ``None`` the 21843-class default
            is used.

    Returns:
        The constructed ``VisionTransformer`` with ``default_cfg`` attached.
    """
    # Pull num_classes out of kwargs so a caller-supplied value overrides the
    # default instead of colliding with it as a duplicate keyword (TypeError).
    num_classes = kwargs.pop('num_classes', None)
    if num_classes is None:
        num_classes = 21843
    model = VisionTransformer(
        patch_size=16, num_classes=num_classes, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['vit_large_patch16_224_in21k']
    if pretrained:
        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model
@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
    """Create the ``vit_large_patch32_224_in21k`` Vision Transformer.

    Args:
        pretrained: when true, call ``load_pretrained`` after the model is constructed.
        **kwargs: forwarded to ``VisionTransformer``. ``num_classes`` may be given
            by the caller; when it is absent or ``None`` the 21843-class default
            is used.

    Returns:
        The constructed ``VisionTransformer`` with ``default_cfg`` attached.
    """
    # Pull num_classes out of kwargs so a caller-supplied value overrides the
    # default instead of colliding with it as a duplicate keyword (TypeError).
    num_classes = kwargs.pop('num_classes', None)
    if num_classes is None:
        num_classes = 21843
    model = VisionTransformer(
        img_size=224, num_classes=num_classes, patch_size=32, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['vit_large_patch32_224_in21k']
    if pretrained:
        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model
@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
    """Create the ``vit_huge_patch14_224_in21k`` Vision Transformer.

    Args:
        pretrained: when true, call ``load_pretrained`` after the model is constructed.
        **kwargs: forwarded to ``VisionTransformer``. ``num_classes`` may be given
            by the caller; when it is absent or ``None`` the 21843-class default
            is used.

    Returns:
        The constructed ``VisionTransformer`` with ``default_cfg`` attached.
    """
    # Pull num_classes out of kwargs so a caller-supplied value overrides the
    # default instead of colliding with it as a duplicate keyword (TypeError).
    num_classes = kwargs.pop('num_classes', None)
    if num_classes is None:
        num_classes = 21843
    model = VisionTransformer(
        img_size=224, patch_size=14, num_classes=num_classes, embed_dim=1280, depth=32, num_heads=16,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['vit_huge_patch14_224_in21k']
    if pretrained:
        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
    """Create the ``vit_base_resnet50_384`` hybrid Vision Transformer.

    A headless ``ResNetV2`` is built as the hybrid backbone and handed to a
    ``VisionTransformer`` (384 image size, embed dim 768, depth 12, 12 heads).

    Args:
        pretrained: when true, call ``load_pretrained`` after the model is constructed.
        **kwargs: forwarded to ``VisionTransformer``. A non-default ``in_chans``
            is additionally passed to the backbone.

    Returns:
        The constructed ``VisionTransformer`` with ``default_cfg`` attached.
    """
    in_chans = kwargs.get('in_chans', 3)
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
    backbone_args = dict(
        layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='')
    if in_chans != 3:
        # The hybrid embedding feeds an `in_chans`-channel tensor through the backbone,
        # so the backbone has to be built for the same channel count.
        # NOTE(review): assumes ResNetV2 accepts an `in_chans` keyword - confirm against resnetv2.py.
        backbone_args['in_chans'] = in_chans
    backbone = ResNetV2(**backbone_args)
    model = VisionTransformer(
        img_size=384, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['vit_base_resnet50_384']
    if pretrained:
        load_pretrained(model, num_classes=model.num_classes, in_chans=in_chans)
    return model

@ -76,8 +76,8 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',

Loading…
Cancel
Save