|
|
|
@ -315,9 +315,10 @@ def create_aa(aa_layer, channels, stride=2, enable=True):
|
|
|
|
|
class BasicBlock(nn.Module):
|
|
|
|
|
expansion = 1
|
|
|
|
|
|
|
|
|
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
|
|
|
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
|
|
|
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
|
|
|
|
|
super(BasicBlock, self).__init__()
|
|
|
|
|
|
|
|
|
|
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
|
|
|
@ -379,9 +380,10 @@ class BasicBlock(nn.Module):
|
|
|
|
|
class Bottleneck(nn.Module):
|
|
|
|
|
expansion = 4
|
|
|
|
|
|
|
|
|
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
|
|
|
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
|
|
|
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
|
|
|
|
|
super(Bottleneck, self).__init__()
|
|
|
|
|
|
|
|
|
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
|
|
|
@ -561,48 +563,35 @@ class ResNet(nn.Module):
|
|
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
|
----------
|
|
|
|
|
block : Block
|
|
|
|
|
Class for the residual block. Options are BasicBlockGl, BottleneckGl.
|
|
|
|
|
layers : list of int
|
|
|
|
|
Numbers of layers in each block
|
|
|
|
|
num_classes : int, default 1000
|
|
|
|
|
Number of classification classes.
|
|
|
|
|
in_chans : int, default 3
|
|
|
|
|
Number of input (color) channels.
|
|
|
|
|
cardinality : int, default 1
|
|
|
|
|
Number of convolution groups for 3x3 conv in Bottleneck.
|
|
|
|
|
base_width : int, default 64
|
|
|
|
|
Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
|
|
|
|
|
stem_width : int, default 64
|
|
|
|
|
Number of channels in stem convolutions
|
|
|
|
|
block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl.
|
|
|
|
|
layers : list of int, number of layers in each block
|
|
|
|
|
num_classes : int, default 1000, number of classification classes.
|
|
|
|
|
in_chans : int, default 3, number of input (color) channels.
|
|
|
|
|
output_stride : int, default 32, output stride of the network, 32, 16, or 8.
|
|
|
|
|
global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
|
|
|
|
|
cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck.
|
|
|
|
|
base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
|
|
|
|
|
stem_width : int, default 64, number of channels in stem convolutions
|
|
|
|
|
stem_type : str, default ''
|
|
|
|
|
The type of stem:
|
|
|
|
|
* '', default - a single 7x7 conv with a width of stem_width
|
|
|
|
|
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
|
|
|
|
|
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
|
|
|
|
|
block_reduce_first: int, default 1
|
|
|
|
|
Reduction factor for first convolution output width of residual blocks,
|
|
|
|
|
1 for all archs except senets, where 2
|
|
|
|
|
down_kernel_size: int, default 1
|
|
|
|
|
Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
|
|
|
|
|
avg_down : bool, default False
|
|
|
|
|
Whether to use average pooling for projection skip connection between stages/downsample.
|
|
|
|
|
output_stride : int, default 32
|
|
|
|
|
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
|
|
|
|
|
block_reduce_first : int, default 1
|
|
|
|
|
Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2
|
|
|
|
|
down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets
|
|
|
|
|
avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample.
|
|
|
|
|
act_layer : nn.Module, activation layer
|
|
|
|
|
norm_layer : nn.Module, normalization layer
|
|
|
|
|
aa_layer : nn.Module, anti-aliasing layer
|
|
|
|
|
drop_rate : float, default 0.
|
|
|
|
|
Dropout probability before classifier, for training
|
|
|
|
|
global_pool : str, default 'avg'
|
|
|
|
|
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
|
|
|
|
|
drop_rate : float, default 0. Dropout probability before classifier, for training
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, block, layers, num_classes=1000, in_chans=3,
|
|
|
|
|
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False,
|
|
|
|
|
output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
|
|
|
|
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
|
|
|
|
|
drop_block_rate=0., global_pool='avg', zero_init_last=True, block_args=None):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
|
|
|
|
|
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1,
|
|
|
|
|
down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
|
|
|
|
|
drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None):
|
|
|
|
|
super(ResNet, self).__init__()
|
|
|
|
|
block_args = block_args or dict()
|
|
|
|
|
assert output_stride in (8, 16, 32)
|
|
|
|
@ -712,12 +701,15 @@ class ResNet(nn.Module):
|
|
|
|
|
x = self.layer4(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.forward_features(x)
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
|
x = self.global_pool(x)
|
|
|
|
|
if self.drop_rate:
|
|
|
|
|
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
|
|
|
|
|
x = self.fc(x)
|
|
|
|
|
return x if pre_logits else self.fc(x)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.forward_features(x)
|
|
|
|
|
x = self.forward_head(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|