@@ -35,9 +35,9 @@ import torch.nn as nn
 from functools import partial
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
 from .registry import register_model
-from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d
+from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
 
 
 def _cfg(url='', **kwargs):
@@ -86,20 +86,10 @@ default_cfgs = {
         url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
         num_classes=21843),
 
-    # trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now
-    # 'resnetv2_50x1_bits': _cfg(
-    #     url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'),
-    # 'resnetv2_50x3_bits': _cfg(
-    #     url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
-    # 'resnetv2_101x1_bits': _cfg(
-    #     url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
-    # 'resnetv2_101x3_bits': _cfg(
-    #     url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
-    # 'resnetv2_152x2_bits': _cfg(
-    #     url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
-    # 'resnetv2_152x4_bits': _cfg(
-    #     url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
-
+    'resnetv2_50': _cfg(
+        input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'),
+    'resnetv2_50d': _cfg(
+        input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic', first_conv='stem.conv1'),
 }
@@ -111,13 +101,6 @@ def make_div(v, divisor=8):
     return new_v
 
 
-def tf2th(conv_weights):
-    """Possibly convert HWIO to OIHW."""
-    if conv_weights.ndim == 4:
-        conv_weights = conv_weights.transpose([3, 2, 0, 1])
-    return torch.from_numpy(conv_weights)
-
-
 class PreActBottleneck(nn.Module):
     """Pre-activation (v2) bottleneck block.
@@ -152,6 +135,9 @@ class PreActBottleneck(nn.Module):
         self.conv3 = conv_layer(mid_chs, out_chs, 1)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
 
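+    # zero-init the final norm's scale so the residual branch starts near zero (block acts as identity at init)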
+    def zero_init_last_bn(self):
+        nn.init.zeros_(self.norm3.weight)
+
     def forward(self, x):
         x_preact = self.norm1(x)
@@ -198,6 +184,9 @@ class Bottleneck(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
         self.act3 = act_layer(inplace=True)
 
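+    # same identity-at-init trick as in PreActBottleneck above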
+    def zero_init_last_bn(self):
+        nn.init.zeros_(self.norm3.weight)
+
     def forward(self, x):
         # shortcut branch
         shortcut = x
@@ -276,7 +265,7 @@ class ResNetStage(nn.Module):
 
 def create_resnetv2_stem(
         in_chs, out_chs=64, stem_type='', preact=True,
-        conv_layer=partial(StdConv2d, eps=1e-8), norm_layer=partial(GroupNormAct, num_groups=32)):
+        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
     stem = OrderedDict()
     assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
@@ -285,14 +274,17 @@ def create_resnetv2_stem(
         # A 3 deep 3x3 conv stack as in ResNet V1D models
         mid_chs = out_chs // 2
         stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
+        stem['norm1'] = norm_layer(mid_chs)
         stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
+        stem['norm2'] = norm_layer(mid_chs)
         stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
+        if not preact:
+            stem['norm3'] = norm_layer(out_chs)
     else:
         # The usual 7x7 stem conv
         stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
-
-    if not preact:
-        stem['norm'] = norm_layer(out_chs)
+        if not preact:
+            stem['norm'] = norm_layer(out_chs)
 
     if 'fixed' in stem_type:
         # 'fixed' SAME padding approximation that is used in BiT models
@@ -312,11 +304,12 @@ class ResNetV2(nn.Module):
     """Implementation of Pre-activation (v2) ResNet mode.
     """
 
-    def __init__(self, layers, channels=(256, 512, 1024, 2048),
-                 num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-                 width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
-                 act_layer=nn.ReLU, conv_layer=partial(StdConv2d, eps=1e-8),
-                 norm_layer=partial(GroupNormAct, num_groups=32), drop_rate=0., drop_path_rate=0.):
+    def __init__(
+            self, layers, channels=(256, 512, 1024, 2048),
+            num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
+            width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
+            act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
+            drop_rate=0., drop_path_rate=0., zero_init_last_bn=True):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -354,12 +347,14 @@
         self.head = ClassifierHead(
             self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
 
-        for n, m in self.named_modules():
-            if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
-                nn.init.normal_(m.weight, mean=0.0, std=0.01)
-                nn.init.zeros_(m.bias)
-            elif isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        self.init_weights(zero_init_last_bn=zero_init_last_bn)
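+
+    # named_apply walks all submodules, calling _init_weights with each module and its dotted name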
+    def init_weights(self, zero_init_last_bn=True):
+        named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self)
+
+    @torch.jit.ignore()
+    def load_pretrained(self, checkpoint_path, prefix='resnet/'):
+        _load_weights(self, checkpoint_path, prefix)
 
     def get_classifier(self):
         return self.head.fc
@@ -378,41 +373,59 @@
     def forward(self, x):
         x = self.forward_features(x)
         x = self.head(x)
-        if not self.head.global_pool.is_identity():
-            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
         return x
 
-    def load_pretrained(self, checkpoint_path, prefix='resnet/'):
-        import numpy as np
-        weights = np.load(checkpoint_path)
-        with torch.no_grad():
-            stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
-            if self.stem.conv.weight.shape[1] == 1:
-                self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
-                # FIXME handle > 3 in_chans?
-            else:
-                self.stem.conv.weight.copy_(stem_conv_w)
-            self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
-            self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
-            if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
-                self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
-                self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
-            for i, (sname, stage) in enumerate(self.stages.named_children()):
-                for j, (bname, block) in enumerate(stage.blocks.named_children()):
-                    convname = 'standardized_conv2d'
-                    block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
-                    block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
-                    block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel']))
-                    block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel']))
-                    block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma']))
-                    block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma']))
-                    block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma']))
-                    block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
-                    block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta']))
-                    block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta']))
-                    if block.downsample is not None:
-                        w = weights[f'{block_prefix}a/proj/{convname}/kernel']
-                        block.downsample.conv.weight.copy_(tf2th(w))
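+
+
+# called for every (name, module) pair via named_apply; `name` is the module's dotted path within the model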
+def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
+    if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
+        nn.init.normal_(module.weight, mean=0.0, std=0.01)
+        nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
+        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
+        nn.init.ones_(module.weight)
+        nn.init.zeros_(module.bias)
+    elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'):
+        module.zero_init_last_bn()
+
+
+@torch.no_grad()
+def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'):
+    import numpy as np
+
+    def t2p(conv_weights):
+        """Possibly convert HWIO to OIHW."""
+        if conv_weights.ndim == 4:
+            conv_weights = conv_weights.transpose([3, 2, 0, 1])
+        return torch.from_numpy(conv_weights)
+
+    weights = np.load(checkpoint_path)
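+    # adapt_input_conv adjusts the 3-channel RGB stem kernel when the model's in_chans differs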
+    stem_conv_w = adapt_input_conv(
+        model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
+    model.stem.conv.weight.copy_(stem_conv_w)
+    model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
+    model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
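+    # only copy classifier weights when the checkpoint head matches this model's num_classes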
+    if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
+        model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
+        model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
+    for i, (sname, stage) in enumerate(model.stages.named_children()):
+        for j, (bname, block) in enumerate(stage.blocks.named_children()):
+            cname = 'standardized_conv2d'
+            block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
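+            # TF checkpoint keys are 1-indexed: block1..block4 and zero-padded unit01, unit02, ...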
+            block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel']))
+            block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel']))
+            block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel']))
+            block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma']))
+            block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma']))
+            block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma']))
+            block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta']))
+            block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta']))
+            block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta']))
+            if block.downsample is not None:
+                w = weights[f'{block_prefix}a/proj/{cname}/kernel']
+                block.downsample.conv.weight.copy_(t2p(w))
 
 
 def _create_resnetv2(variant, pretrained=False, **kwargs):
@@ -425,130 +438,99 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
         **kwargs)
 
 
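+# BiT checkpoints assume the 'fixed' SAME-pad stem and eps=1e-8 weight-standardized convs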
+def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
+    return _create_resnetv2(
+        variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)
+
+
 @register_model
 def resnetv2_50x1_bitm(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50x1_bitm', pretrained=pretrained,
-        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
+    return _create_resnetv2_bit(
+        'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
 
 
 @register_model
 def resnetv2_50x3_bitm(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50x3_bitm', pretrained=pretrained,
-        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
+    return _create_resnetv2_bit(
+        'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
 
 
 @register_model
 def resnetv2_101x1_bitm(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101x1_bitm', pretrained=pretrained,
-        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
+    return _create_resnetv2_bit(
+        'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
 
 
 @register_model
 def resnetv2_101x3_bitm(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101x3_bitm', pretrained=pretrained,
-        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
+    return _create_resnetv2_bit(
+        'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
 
 
 @register_model
 def resnetv2_152x2_bitm(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_152x2_bitm', pretrained=pretrained,
-        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
+    return _create_resnetv2_bit(
+        'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
 
 
 @register_model
 def resnetv2_152x4_bitm(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_152x4_bitm', pretrained=pretrained,
-        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
 +    return _create_resnetv2_bit(
+        'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
 
 
 @register_model
 def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2(
+    return _create_resnetv2_bit(
         'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
+        layers=[3, 4, 6, 3], width_factor=1, **kwargs)
 
 
 @register_model
 def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2(
+    return _create_resnetv2_bit(
         'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
+        layers=[3, 4, 6, 3], width_factor=3, **kwargs)
 
 
 @register_model
 def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2(
+    return _create_resnetv2_bit(
         'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
+        layers=[3, 4, 23, 3], width_factor=1, **kwargs)
 
 
 @register_model
 def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2(
+    return _create_resnetv2_bit(
         'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
+        layers=[3, 4, 23, 3], width_factor=3, **kwargs)
 
 
 @register_model
 def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2(
+    return _create_resnetv2_bit(
         'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
+        layers=[3, 8, 36, 3], width_factor=2, **kwargs)
 
 
 @register_model
 def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2(
+    return _create_resnetv2_bit(
         'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
+        layers=[3, 8, 36, 3], width_factor=4, **kwargs)
 
 
+# NOTE the 'S' versions of the model weights aren't as interesting as original 21k or transfer to 1K M.
+
+
+@register_model
+def resnetv2_50(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_50', pretrained=pretrained,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs)
+
+
-# @register_model
-# def resnetv2_50x1_bits(pretrained=False, **kwargs):
-#     return _create_resnetv2(
-#         'resnetv2_50x1_bits', pretrained=pretrained,
-#         layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
-#
-#
-# @register_model
-# def resnetv2_50x3_bits(pretrained=False, **kwargs):
-#     return _create_resnetv2(
-#         'resnetv2_50x3_bits', pretrained=pretrained,
-#         layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
-#
-#
-# @register_model
-# def resnetv2_101x1_bits(pretrained=False, **kwargs):
-#     return _create_resnetv2(
-#         'resnetv2_101x1_bits', pretrained=pretrained,
-#         layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
-#
-#
-# @register_model
-# def resnetv2_101x3_bits(pretrained=False, **kwargs):
-#     return _create_resnetv2(
-#         'resnetv2_101x3_bits', pretrained=pretrained,
-#         layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
-#
-#
-# @register_model
-# def resnetv2_152x2_bits(pretrained=False, **kwargs):
-#     return _create_resnetv2(
-#         'resnetv2_152x2_bits', pretrained=pretrained,
-#         layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
-#
-#
-# @register_model
-# def resnetv2_152x4_bits(pretrained=False, **kwargs):
-#     return _create_resnetv2(
-#         'resnetv2_152x4_bits', pretrained=pretrained,
-#         layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
-#
+@register_model
+def resnetv2_50d(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_50d', pretrained=pretrained,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d,
+        stem_type='deep', avg_down=True, **kwargs)