@ -7,13 +7,12 @@ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered ste
"""
import math
import torch
import torch . nn as nn
import torch . nn . functional as F
from . registry import register_model
from . helpers import load_pretrained
from . layers import EcaModule, SelectAdaptivePool2d, DropBlock2d , DropPath
from . layers import SelectAdaptivePool2d, DropBlock2d , DropPath , AvgPool2dSame , create_attn
from timm . data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
@ -103,7 +102,8 @@ default_cfgs = {
' ecaresnext26tn_32x4d ' : _cfg (
url = ' ' ,
interpolation = ' bicubic ' ) ,
' ecaresnet18 ' : _cfg ( ) ,
' ecaresnet50 ' : _cfg ( ) ,
}
@ -112,32 +112,12 @@ def get_padding(kernel_size, stride, dilation=1):
return padding
class SEModule ( nn . Module ) :
def __init__ ( self , channels , reduction_channels ) :
super ( SEModule , self ) . __init__ ( )
self . avg_pool = nn . AdaptiveAvgPool2d ( 1 )
self . fc1 = nn . Conv2d (
channels , reduction_channels , kernel_size = 1 , padding = 0 , bias = True )
self . relu = nn . ReLU ( inplace = True )
self . fc2 = nn . Conv2d (
reduction_channels , channels , kernel_size = 1 , padding = 0 , bias = True )
def forward ( self , x ) :
x_se = self . avg_pool ( x )
x_se = self . fc1 ( x_se )
x_se = self . relu ( x_se )
x_se = self . fc2 ( x_se )
return x * x_se . sigmoid ( )
class BasicBlock ( nn . Module ) :
__constants__ = [ ' se ' , ' downsample ' ] # for pre 1.4 torchscript compat
expansion = 1
def __init__ ( self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 , use_se = False ,
def __init__ ( self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 ,
reduce_first = 1 , dilation = 1 , first_dilation = None , act_layer = nn . ReLU , norm_layer = nn . BatchNorm2d ,
drop_block= None , drop_path = None ) :
attn_layer = None , drop_block = None , drop_path = None ) :
super ( BasicBlock , self ) . __init__ ( )
assert cardinality == 1 , ' BasicBlock only supports cardinality of 1 '
@ -155,7 +135,7 @@ class BasicBlock(nn.Module):
first_planes , outplanes , kernel_size = 3 , padding = dilation , dilation = dilation , bias = False )
self . bn2 = norm_layer ( outplanes )
self . se = SEModule( outplanes , planes / / 4 ) if use_se else None
self . se = create_attn( attn_layer , outplanes )
self . act2 = act_layer ( inplace = True )
self . downsample = downsample
@ -199,9 +179,9 @@ class Bottleneck(nn.Module):
__constants__ = [ ' se ' , ' downsample ' ] # for pre 1.4 torchscript compat
expansion = 4
def __init__ ( self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 , use_se = False ,
def __init__ ( self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 ,
reduce_first = 1 , dilation = 1 , first_dilation = None , act_layer = nn . ReLU , norm_layer = nn . BatchNorm2d ,
drop_block= None , drop_path = None ) :
attn_layer= None , drop_block= None , drop_path = None ) :
super ( Bottleneck , self ) . __init__ ( )
width = int ( math . floor ( planes * ( base_width / 64 ) ) * cardinality )
@ -220,7 +200,7 @@ class Bottleneck(nn.Module):
self . conv3 = nn . Conv2d ( width , outplanes , kernel_size = 1 , bias = False )
self . bn3 = norm_layer ( outplanes )
self . se = SEModule( outplanes , planes / / 4 ) if use_se else None
self . se = create_attn( attn_layer , outplanes )
self . act3 = act_layer ( inplace = True )
self . downsample = downsample
@ -266,6 +246,37 @@ class Bottleneck(nn.Module):
return x
def downsample_conv (
in_channels , out_channels , kernel_size , stride = 1 , dilation = 1 , first_dilation = None , norm_layer = None ) :
norm_layer = norm_layer or nn . BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
first_dilation = ( first_dilation or dilation ) if kernel_size > 1 else 1
p = get_padding ( kernel_size , stride , first_dilation )
return nn . Sequential ( * [
nn . Conv2d (
in_channels , out_channels , kernel_size , stride = stride , padding = p , dilation = first_dilation , bias = False ) ,
norm_layer ( out_channels )
] )
def downsample_avg (
in_channels , out_channels , kernel_size , stride = 1 , dilation = 1 , first_dilation = None , norm_layer = None ) :
norm_layer = norm_layer or nn . BatchNorm2d
avg_stride = stride if dilation == 1 else 1
if stride == 1 and dilation == 1 :
pool = nn . Identity ( )
else :
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn . AvgPool2d
pool = avg_pool_fn ( 2 , avg_stride , ceil_mode = True , count_include_pad = False )
return nn . Sequential ( * [
pool ,
nn . Conv2d ( in_channels , out_channels , 1 , stride = 1 , padding = 0 , bias = False ) ,
norm_layer ( out_channels )
] )
class ResNet ( nn . Module ) :
""" ResNet / ResNeXt / SE-ResNeXt / SE-Net
@ -307,8 +318,6 @@ class ResNet(nn.Module):
Number of classification classes .
in_chans : int , default 3
Number of input ( color ) channels .
use_se : bool , default False
Enable Squeeze - Excitation module in blocks
cardinality : int , default 1
Number of convolution groups for 3 x3 conv in Bottleneck .
base_width : int , default 64
@ -337,7 +346,7 @@ class ResNet(nn.Module):
global_pool : str , default ' avg '
Global pooling type . One of ' avg ' , ' max ' , ' avgmax ' , ' catavgmax '
"""
def __init__ ( self , block , layers , num_classes = 1000 , in_chans = 3 , use_se = False , use_eca = False ,
def __init__ ( self , block , layers , num_classes = 1000 , in_chans = 3 ,
cardinality = 1 , base_width = 64 , stem_width = 64 , stem_type = ' ' ,
block_reduce_first = 1 , down_kernel_size = 1 , avg_down = False , output_stride = 32 ,
act_layer = nn . ReLU , norm_layer = nn . BatchNorm2d , drop_rate = 0.0 , drop_path_rate = 0. ,
@ -385,14 +394,14 @@ class ResNet(nn.Module):
dilations [ 2 : 4 ] = [ 2 , 4 ]
else :
assert output_stride == 32
l l args = list ( zip ( channels , layers , strides , dilations ) )
l kwargs = dict (
use_se= use_se , reduce_first= block_reduce_first , act_layer = act_layer , norm_layer = norm_layer ,
l ayer_ args = list ( zip ( channels , layers , strides , dilations ) )
l ayer_ kwargs = dict (
reduce_first= block_reduce_first , act_layer = act_layer , norm_layer = norm_layer ,
avg_down = avg_down , down_kernel_size = down_kernel_size , drop_path = dp , * * block_args )
self . layer1 = self . _make_layer ( block , * l l args[ 0 ] , * * l kwargs)
self . layer2 = self . _make_layer ( block , * l l args[ 1 ] , * * l kwargs)
self . layer3 = self . _make_layer ( block , drop_block = db_3 , * l l args[ 2 ] , * * l kwargs)
self . layer4 = self . _make_layer ( block , drop_block = db_4 , * l l args[ 3 ] , * * l kwargs)
self . layer1 = self . _make_layer ( block , * l ayer_ args[ 0 ] , * * l ayer_ kwargs)
self . layer2 = self . _make_layer ( block , * l ayer_ args[ 1 ] , * * l ayer_ kwargs)
self . layer3 = self . _make_layer ( block , drop_block = db_3 , * l ayer_ args[ 2 ] , * * l ayer_ kwargs)
self . layer4 = self . _make_layer ( block , drop_block = db_4 , * l ayer_ args[ 3 ] , * * l ayer_ kwargs)
# Head (Pooling and Classifier)
self . global_pool = SelectAdaptivePool2d ( pool_type = global_pool )
@ -411,31 +420,21 @@ class ResNet(nn.Module):
m . zero_init_last_bn ( )
def _make_layer ( self , block , planes , blocks , stride = 1 , dilation = 1 , reduce_first = 1 ,
use_se = False , use_eca = False , avg_down = False , down_kernel_size = 1 , * * kwargs ) :
norm_layer = kwargs . get ( ' norm_layer ' )
avg_down = False , down_kernel_size = 1 , * * kwargs ) :
downsample = None
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
first_dilation = 1 if dilation in ( 1 , 2 ) else 2
if stride != 1 or self . inplanes != planes * block . expansion :
downsample_padding = get_padding ( down_kernel_size , stride )
downsample_layers = [ ]
conv_stride = stride
if avg_down :
avg_stride = stride if dilation == 1 else 1
conv_stride = 1
downsample_layers = [ nn . AvgPool2d ( avg_stride , avg_stride , ceil_mode = True , count_include_pad = False ) ]
downsample_layers + = [
nn . Conv2d ( self . inplanes , planes * block . expansion , down_kernel_size ,
stride = conv_stride , padding = downsample_padding , bias = False ) ,
norm_layer ( planes * block . expansion ) ]
downsample = nn . Sequential ( * downsample_layers )
downsample_args = dict (
in_channels = self . inplanes , out_channels = planes * block . expansion , kernel_size = down_kernel_size ,
stride = stride , dilation = dilation , first_dilation = first_dilation , norm_layer = kwargs . get ( ' norm_layer ' ) )
downsample = downsample_avg ( * * downsample_args ) if avg_down else downsample_conv ( * * downsample_args )
first_dilation = 1 if dilation in ( 1 , 2 ) else 2
bkwargs = dict (
block_kwargs = dict (
cardinality = self . cardinality , base_width = self . base_width , reduce_first = reduce_first ,
dilation = dilation , use_se = use_se , * * kwargs )
layers = [ block ( self . inplanes , planes , stride , downsample , first_dilation = first_dilation , * * b kwargs) ]
dilation = dilation , * * kwargs )
layers = [ block ( self . inplanes , planes , stride , downsample , first_dilation = first_dilation , * * block_kwargs ) ]
self . inplanes = planes * block . expansion
layers + = [ block ( self . inplanes , planes , * * b kwargs) for _ in range ( 1 , blocks ) ]
layers + = [ block ( self . inplanes , planes , * * block_kwargs ) for _ in range ( 1 , blocks ) ]
return nn . Sequential ( * layers )
@ -936,9 +935,8 @@ def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
"""
default_cfg = default_cfgs [ ' seresnext26d_32x4d ' ]
model = ResNet (
Bottleneck , [ 2 , 2 , 2 , 2 ] , cardinality = 32 , base_width = 4 ,
stem_width = 32 , stem_type = ' deep ' , avg_down = True , use_se = True ,
num_classes = num_classes , in_chans = in_chans , * * kwargs )
Bottleneck , [ 2 , 2 , 2 , 2 ] , cardinality = 32 , base_width = 4 , stem_width = 32 , stem_type = ' deep ' , avg_down = True ,
num_classes = num_classes , in_chans = in_chans , block_args = dict ( attn_layer = ' se ' ) , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
@ -954,8 +952,8 @@ def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
default_cfg = default_cfgs [ ' seresnext26t_32x4d ' ]
model = ResNet (
Bottleneck , [ 2 , 2 , 2 , 2 ] , cardinality = 32 , base_width = 4 ,
stem_width = 32 , stem_type = ' deep_tiered ' , avg_down = True , use_se = True ,
num_classes = num_classes , in_chans = in_chans , * * kwargs )
stem_width = 32 , stem_type = ' deep_tiered ' , avg_down = True ,
num_classes = num_classes , in_chans = in_chans , block_args = dict ( attn_layer = ' se ' ) , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
@ -971,25 +969,55 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs
default_cfg = default_cfgs [ ' seresnext26tn_32x4d ' ]
model = ResNet (
Bottleneck , [ 2 , 2 , 2 , 2 ] , cardinality = 32 , base_width = 4 ,
stem_width = 32 , stem_type = ' deep_tiered_narrow ' , avg_down = True , use_se = True ,
num_classes = num_classes , in_chans = in_chans , * * kwargs )
stem_width = 32 , stem_type = ' deep_tiered_narrow ' , avg_down = True ,
num_classes = num_classes , in_chans = in_chans , block_args = dict ( attn_layer = ' se ' ) , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
@register_model
def ecaresnext26tn_32x4d ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
""" Constructs a eca -ResNeXt-26-TN model.
""" Constructs a n ECA -ResNeXt-26-TN model.
This is technically a 28 layer ResNet , like a ' D ' bag - of - tricks model but with tiered 24 , 32 , 64 channels
in the deep stem . The channel number of the middle stem conv is narrower than the ' T ' variant .
this model replaces SE module with the ECA module
"""
default_cfg = default_cfgs [ ' ecaresnext26tn_32x4d ' ]
block_args = dict ( attn_layer = ' eca ' )
model = ResNet (
Bottleneck , [ 2 , 2 , 2 , 2 ] , cardinality = 32 , base_width = 4 ,
stem_width = 32 , stem_type = ' deep_tiered_narrow ' , avg_down = True , use_eca = True ,
num_classes = num_classes , in_chans = in_chans , * * kwargs )
stem_width = 32 , stem_type = ' deep_tiered_narrow ' , avg_down = True ,
num_classes = num_classes , in_chans = in_chans , block_args = block_args , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
@register_model
def ecaresnet18 ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
""" Constructs an ECA-ResNet-18 model.
"""
default_cfg = default_cfgs [ ' ecaresnet18 ' ]
block_args = dict ( attn_layer = ' eca ' )
model = ResNet (
BasicBlock , [ 2 , 2 , 2 , 2 ] , num_classes = num_classes , in_chans = in_chans , block_args = block_args , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
@register_model
def ecaresnet50 ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
""" Constructs an ECA-ResNet-50 model.
"""
default_cfg = default_cfgs [ ' ecaresnet50 ' ]
block_args = dict ( attn_layer = ' eca ' )
model = ResNet (
Bottleneck , [ 3 , 4 , 6 , 3 ] , num_classes = num_classes , in_chans = in_chans , block_args = block_args , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )