@@ -205,14 +205,14 @@ class BasicBlock(nn.Module):
         first_planes = planes // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation
-        use_aa = aa_layer is not None
+        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
 
         self.conv1 = nn.Conv2d(
             inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
             dilation=first_dilation, bias=False)
         self.bn1 = norm_layer(first_planes)
         self.act1 = act_layer(inplace=True)
-        self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
+        self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None
 
         self.conv2 = nn.Conv2d(
             first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@@ -272,7 +272,7 @@ class Bottleneck(nn.Module):
         first_planes = width // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation
-        use_aa = aa_layer is not None
+        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
 
         self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
         self.bn1 = norm_layer(first_planes)
@@ -283,7 +283,7 @@ class Bottleneck(nn.Module):
             padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
         self.bn2 = norm_layer(width)
         self.act2 = act_layer(inplace=True)
-        self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
+        self.aa = aa_layer(channels=width, stride=stride) if use_aa else None
 
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
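Note: in both BasicBlock and Bottleneck above, the change moves the stride decision into the anti-aliasing layer itself: when use_aa is true, the 3x3 conv runs at stride 1 and the aa layer performs the downsampling. Below is a minimal shape-check sketch of that pattern; nn.AvgPool2d merely stands in for a real anti-aliasing layer (the actual aa_layer, e.g. a blur pool, is not shown here), so treat it as an illustration, not the patched code.

import torch
import torch.nn as nn

# Stand-in aa layer: any module taking (channels, stride) and downsampling by `stride`.
def aa_layer(channels, stride):
    return nn.AvgPool2d(kernel_size=stride, stride=stride)

stride = 2
use_aa = True  # i.e. aa_layer is not None and (stride == 2 or first_dilation != dilation)
conv = nn.Conv2d(64, 64, kernel_size=3, stride=1 if use_aa else stride, padding=1, bias=False)
aa = aa_layer(channels=64, stride=stride) if use_aa else nn.Identity()

x = torch.randn(1, 64, 56, 56)
print(aa(conv(x)).shape)  # torch.Size([1, 64, 28, 28]), same reduction as a strided conv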
@@ -336,14 +336,6 @@ class Bottleneck(nn.Module):
         return x
 
 
-def setup_drop_block(drop_block_rate=0.):
-    return [
-        None,
-        None,
-        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
-        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
-
-
 def downsample_conv(
         in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
     norm_layer = norm_layer or nn.BatchNorm2d
@@ -375,6 +367,57 @@ def downsample_avg(
     ])
 
 
+def drop_blocks(drop_block_rate=0.):
+    return [
+        None, None,
+        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
+        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
+
+
+def make_blocks(
+        block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
+        down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
+    stages = []
+    feature_info = []
+    net_num_blocks = sum(block_repeats)
+    net_block_idx = 0
+    net_stride = 4
+    dilation = prev_dilation = 1
+    for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
+        stage_name = f'layer{stage_idx + 1}'  # never liked this name, but weight compat requires it
+        stride = 1 if stage_idx == 0 else 2
+        if net_stride >= output_stride:
+            dilation *= stride
+            stride = 1
+        else:
+            net_stride *= stride
+
+        downsample = None
+        if stride != 1 or inplanes != planes * block_fn.expansion:
+            down_kwargs = dict(
+                in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size,
+                stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'))
+            downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
+
+        block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
+        blocks = []
+        for block_idx in range(num_blocks):
+            downsample = downsample if block_idx == 0 else None
+            stride = stride if block_idx == 0 else 1
+            block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1)  # stochastic depth linear decay rule
+            blocks.append(block_fn(
+                inplanes, planes, stride, downsample, first_dilation=prev_dilation,
+                drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs))
+            prev_dilation = dilation
+            inplanes = planes * block_fn.expansion
+            net_block_idx += 1
+        stages.append((stage_name, nn.Sequential(*blocks)))
+        feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
+
+    return stages, feature_info
+
+
 class ResNet(nn.Module):
     """ResNet / ResNeXt / SE-ResNeXt / SE-Net
@@ -448,21 +491,18 @@ class ResNet(nn.Module):
     def __init__(self, block, layers, num_classes=1000, in_chans=3,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
-                 block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
+                 output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
                  drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
         assert output_stride in (8, 16, 32)
         self.num_classes = num_classes
-        deep_stem = 'deep' in stem_type
-        self.inplanes = stem_width * 2 if deep_stem else 64
-        self.cardinality = cardinality
-        self.base_width = base_width
         self.drop_rate = drop_rate
-        self.expansion = block.expansion
         super(ResNet, self).__init__()
 
         # Stem
+        deep_stem = 'deep' in stem_type
+        inplanes = stem_width * 2 if deep_stem else 64
         if deep_stem:
             stem_chs_1 = stem_chs_2 = stem_width
             if 'tiered' in stem_type:
@@ -475,43 +515,31 @@ class ResNet(nn.Module):
                 nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
                 norm_layer(stem_chs_2),
                 act_layer(inplace=True),
-                nn.Conv2d(stem_chs_2, self.inplanes, 3, stride=1, padding=1, bias=False)])
+                nn.Conv2d(stem_chs_2, inplanes, 3, stride=1, padding=1, bias=False)])
         else:
-            self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
-        self.bn1 = norm_layer(self.inplanes)
+            self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = norm_layer(inplanes)
         self.act1 = act_layer(inplace=True)
-        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
+        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
 
         # Stem Pooling
         if aa_layer is not None:
             self.maxpool = nn.Sequential(*[
                 nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                aa_layer(channels=self.inplanes, stride=2)
-            ])
+                aa_layer(channels=inplanes, stride=2)])
         else:
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
         channels = [64, 128, 256, 512]
-        dp = DropPath(drop_path_rate) if drop_path_rate else None
-        db = setup_drop_block(drop_block_rate)
-        layer_kwargs = dict(
-            reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
-            avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
-        total_stride = 4
-        dilation = 1
-        for i in range(4):
-            layer_name = f'layer{i + 1}'
-            stride = 2 if i > 0 else 1
-            if total_stride >= output_stride:
-                dilation *= stride
-                stride = 1
-            else:
-                total_stride *= stride
-            self.add_module(layer_name, self._make_layer(
-                block, channels[i], layers[i], stride, dilation, drop_block=db[i], **layer_kwargs))
-            self.feature_info.append(dict(
-                num_chs=self.inplanes, reduction=total_stride, module=layer_name))
+        stage_modules, stage_feature_info = make_blocks(
+            block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
+            output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
+            down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
+            drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
+        for stage in stage_modules:
+            self.add_module(*stage)  # layer1, layer2, etc
+        self.feature_info.extend(stage_feature_info)
 
         # Head (Pooling and Classifier)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
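Note: self.add_module(*stage) unpacks the (name, module) tuples returned by make_blocks, so layer1..layer4 keep the attribute names, and hence the checkpoint keys, that _make_layer used to create. A tiny sketch of that registration pattern, with hypothetical names:

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        stages = [
            ('layer1', nn.Sequential(nn.Conv2d(3, 8, 3, padding=1))),
            ('layer2', nn.Sequential(nn.Conv2d(8, 16, 3, padding=1)))]
        for stage in stages:
            self.add_module(*stage)  # same as self.layer1 = ...; keeps state_dict keys stable

print([k for k, _ in Toy().named_children()])  # ['layer1', 'layer2']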
@@ -529,25 +557,6 @@ class ResNet(nn.Module):
                 if hasattr(m, 'zero_init_last_bn'):
                     m.zero_init_last_bn()
 
-    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
-                    avg_down=False, down_kernel_size=1, **kwargs):
-        downsample = None
-        first_dilation = 1 if dilation in (1, 2) else 2
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample_args = dict(
-                in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size,
-                stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer'))
-            downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args)
-        block_kwargs = dict(
-            cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            dilation=dilation, **kwargs)
-        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
-        self.inplanes = planes * block.expansion
-        layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
-        return nn.Sequential(*layers)
-
     def get_classifier(self):
         return self.fc
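Note: a quick usage sketch against the refactored constructor, assuming the rest of this file's imports and helpers. With output_stride=8 the recorded reductions cap at 8 while dilation grows, matching the make_blocks trace shown earlier.

model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride=8)
print([f['reduction'] for f in model.feature_info])  # [2, 4, 8, 8, 8]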