@ -91,6 +91,12 @@ default_cfgs = {
url = ' https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth ' ) ,
' swsl_resnext101_32x16d ' : _cfg (
url = ' https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth ' ) ,
' seresnext26d_32x4d ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth ' ,
interpolation = ' bicubic ' ) ,
' seresnext26t_32x4d ' : _cfg (
url = ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26t_32x4d-361bc1c4.pth ' ,
interpolation = ' bicubic ' ) ,
}
@ -231,10 +237,11 @@ class ResNet(nn.Module):
ResNet variants :
* normal , b - 7 x7 stem , stem_width = 64 , same as torchvision ResNet , NVIDIA ResNet ' v1.5 ' , Gluon v1b
* c - 3 layer deep 3 x3 stem , stem_width = 32
* d - 3 layer deep 3 x3 stem , stem_width = 32 , average pool in downsample
* e - 3 layer deep 3 x3 stem , stem_width = 64 , average pool in downsample
* s - 3 layer deep 3 x3 stem , stem_width = 64
* c - 3 layer deep 3 x3 stem , stem_width = 32 ( 32 , 32 , 64 )
* d - 3 layer deep 3 x3 stem , stem_width = 32 ( 32 , 32 , 64 ) , average pool in downsample
* e - 3 layer deep 3 x3 stem , stem_width = 64 ( 64 , 64 , 128 ) , average pool in downsample
* s - 3 layer deep 3 x3 stem , stem_width = 64 ( 64 , 64 , 128 )
* t - 3 layer deep 3 x3 stem , stem width = 32 ( 24 , 48 , 64 ) , average pool in downsample
ResNeXt
* normal - 7 x7 stem , stem_width = 64 , standard cardinality and base widths
@ -263,10 +270,13 @@ class ResNet(nn.Module):
Number of convolution groups for 3 x3 conv in Bottleneck .
base_width : int , default 64
Factor determining bottleneck channels . ` planes * base_width / 64 * cardinality `
deep_stem : bool , default False
Whether to replace the 7 x7 conv1 with 3 3 x3 convolution layers .
stem_width : int , default 64
Number of channels in stem convolutions
stem_type : str , default ' '
The type of stem :
* ' ' , default - a single 7 x7 conv with a width of stem_width
* ' deep ' - three 3 x3 convolution layers of widths stem_width , stem_width , stem_width * 2
* ' deep_tiered ' - three 3 x3 conv layers of widths stem_width / / 4 * 3 , stem_width / / 4 * 6 , stem_width * 2
block_reduce_first : int , default 1
Reduction factor for first convolution output width of residual blocks ,
1 for all archs except senets , where 2
@ -283,12 +293,13 @@ class ResNet(nn.Module):
Global pooling type . One of ' avg ' , ' max ' , ' avgmax ' , ' catavgmax '
"""
def __init__ ( self , block , layers , num_classes = 1000 , in_chans = 3 , use_se = False ,
cardinality = 1 , base_width = 64 , stem_width = 64 , deep_stem= False ,
cardinality = 1 , base_width = 64 , stem_width = 64 , stem_type= ' ' ,
block_reduce_first = 1 , down_kernel_size = 1 , avg_down = False , dilated = False ,
norm_layer = nn . BatchNorm2d , drop_rate = 0.0 , global_pool = ' avg ' ,
zero_init_last_bn = True , block_args = None ) :
block_args = block_args or dict ( )
self . num_classes = num_classes
deep_stem = ' deep ' in stem_type
self . inplanes = stem_width * 2 if deep_stem else 64
self . cardinality = cardinality
self . base_width = base_width
@ -298,16 +309,20 @@ class ResNet(nn.Module):
super ( ResNet , self ) . __init__ ( )
if deep_stem :
stem_chs_1 = stem_chs_2 = stem_width
if ' tiered ' in stem_type :
stem_chs_1 = 3 * ( stem_width / / 4 )
stem_chs_2 = 6 * ( stem_width / / 4 )
self . conv1 = nn . Sequential ( * [
nn . Conv2d ( in_chans , stem_width , 3 , stride = 2 , padding = 1 , bias = False ) ,
norm_layer ( stem_width ) ,
nn . Conv2d ( in_chans , stem_ chs_1 , 3 , stride = 2 , padding = 1 , bias = False ) ,
norm_layer ( stem_ chs_1 ) ,
nn . ReLU ( inplace = True ) ,
nn . Conv2d ( stem_ width, stem_width , 3 , stride = 1 , padding = 1 , bias = False ) ,
norm_layer ( stem_ width ) ,
nn . Conv2d ( stem_ chs_1, stem_chs_2 , 3 , stride = 1 , padding = 1 , bias = False ) ,
norm_layer ( stem_ chs_2 ) ,
nn . ReLU ( inplace = True ) ,
nn . Conv2d ( stem_ width , self . inplanes , 3 , stride = 1 , padding = 1 , bias = False ) ] )
nn . Conv2d ( stem_ chs_2 , self . inplanes , 3 , stride = 1 , padding = 1 , bias = False ) ] )
else :
self . conv1 = nn . Conv2d ( in_chans , stem_width , kernel_size = 7 , stride = 2 , padding = 3 , bias = False )
self . conv1 = nn . Conv2d ( in_chans , self . inplanes , kernel_size = 7 , stride = 2 , padding = 3 , bias = False )
self . bn1 = norm_layer ( self . inplanes )
self . relu = nn . ReLU ( inplace = True )
self . maxpool = nn . MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 1 )
@ -324,7 +339,7 @@ class ResNet(nn.Module):
self . num_features = 512 * block . expansion
self . fc = nn . Linear ( self . num_features * self . global_pool . feat_mult ( ) , num_classes )
last_bn_name = ' bn3 ' if ' Bottle neck ' in block . __name__ else ' bn2 '
last_bn_name = ' bn3 ' if ' Bottle ' in block . __name__ else ' bn2 '
for n , m in self . named_modules ( ) :
if isinstance ( m , nn . Conv2d ) :
nn . init . kaiming_normal_ ( m . weight , mode = ' fan_out ' , nonlinearity = ' relu ' )
@ -440,7 +455,7 @@ def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""
default_cfg = default_cfgs [ ' resnet26d ' ]
model = ResNet (
Bottleneck , [ 2 , 2 , 2 , 2 ] , stem_width = 32 , deep_stem= True , avg_down = True ,
Bottleneck , [ 2 , 2 , 2 , 2 ] , stem_width = 32 , stem_type= ' deep ' , avg_down = True ,
num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
@ -466,7 +481,7 @@ def resnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""
default_cfg = default_cfgs [ ' resnet50d ' ]
model = ResNet (
Bottleneck , [ 3 , 4 , 6 , 3 ] , stem_width = 32 , deep_stem= True , avg_down = True ,
Bottleneck , [ 3 , 4 , 6 , 3 ] , stem_width = 32 , stem_type= ' deep ' , avg_down = True ,
num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
@ -574,7 +589,7 @@ def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
default_cfg = default_cfgs [ ' resnext50d_32x4d ' ]
model = ResNet (
Bottleneck , [ 3 , 4 , 6 , 3 ] , cardinality = 32 , base_width = 4 ,
stem_width = 32 , deep_stem= True , avg_down = True ,
stem_width = 32 , stem_type= ' deep ' , avg_down = True ,
num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
@ -854,3 +869,34 @@ def swsl_resnext101_32x16d(pretrained=True, **kwargs):
if pretrained :
load_pretrained ( model , num_classes = kwargs . get ( ' num_classes ' , 0 ) , in_chans = kwargs . get ( ' in_chans ' , 3 ) )
return model
@register_model
def seresnext26d_32x4d ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
""" Constructs a ResNet-26 v1d model.
This is technically a 28 layer ResNet , sticking with ' d ' modifier from Gluon for now .
"""
default_cfg = default_cfgs [ ' seresnext26d_32x4d ' ]
model = ResNet (
Bottleneck , [ 2 , 2 , 2 , 2 ] , cardinality = 32 , base_width = 4 ,
stem_width = 32 , stem_type = ' deep ' , avg_down = True , use_se = True ,
num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model
@register_model
def seresnext26t_32x4d ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
""" Constructs a ResNet-26 v1d model.
"""
default_cfg = default_cfgs [ ' seresnext26t_32x4d ' ]
model = ResNet (
Bottleneck , [ 2 , 2 , 2 , 2 ] , cardinality = 32 , base_width = 4 ,
stem_width = 32 , stem_type = ' deep_tiered ' , avg_down = True , use_se = True ,
num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model