|
|
|
@ -1,11 +1,13 @@
|
|
|
|
|
"""Pytorch ResNet implementation w/ tweaks
|
|
|
|
|
This file is a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
|
|
|
|
|
"""PyTorch ResNet
|
|
|
|
|
|
|
|
|
|
This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
|
|
|
|
|
additional dropout and dynamic global avg/max pool.
|
|
|
|
|
|
|
|
|
|
ResNext additions added by Ross Wightman
|
|
|
|
|
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants added by Ross Wightman
|
|
|
|
|
"""
|
|
|
|
|
import math
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
@ -33,6 +35,12 @@ default_cfgs = {
|
|
|
|
|
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
|
|
|
|
|
'resnet34': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
|
|
|
|
|
'resnet26': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'resnet26d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'resnet50': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
@ -45,6 +53,7 @@ default_cfgs = {
|
|
|
|
|
'resnext50_32x4d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'resnext50d_32x4d': _cfg(url=''),
|
|
|
|
|
'resnext101_32x4d': _cfg(url=''),
|
|
|
|
|
'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
|
|
|
|
|
'resnext101_64x4d': _cfg(url=''),
|
|
|
|
@ -56,30 +65,60 @@ default_cfgs = {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def conv3x3(in_planes, out_planes, stride=1):
|
|
|
|
|
"""3x3 convolution with padding"""
|
|
|
|
|
return nn.Conv2d(
|
|
|
|
|
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
|
|
|
|
def _get_padding(kernel_size, stride, dilation=1):
|
|
|
|
|
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
|
|
|
|
return padding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SEModule(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, channels, reduction_channels):
|
|
|
|
|
super(SEModule, self).__init__()
|
|
|
|
|
#self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
|
|
|
self.fc1 = nn.Conv2d(
|
|
|
|
|
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
|
|
|
|
|
self.relu = nn.ReLU()
|
|
|
|
|
self.fc2 = nn.Conv2d(
|
|
|
|
|
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
|
|
|
|
|
self.sigmoid = nn.Sigmoid()
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
module_input = x
|
|
|
|
|
#x = self.avg_pool(x)
|
|
|
|
|
x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
|
|
|
|
|
x = self.fc1(x)
|
|
|
|
|
x = self.relu(x)
|
|
|
|
|
x = self.fc2(x)
|
|
|
|
|
x = self.sigmoid(x)
|
|
|
|
|
return module_input * x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasicBlock(nn.Module):
|
|
|
|
|
expansion = 1
|
|
|
|
|
|
|
|
|
|
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
|
|
|
|
cardinality=1, base_width=64, drop_rate=0.0):
|
|
|
|
|
cardinality=1, base_width=64, use_se=False,
|
|
|
|
|
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
|
|
|
|
super(BasicBlock, self).__init__()
|
|
|
|
|
|
|
|
|
|
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
|
|
|
|
assert base_width == 64, 'BasicBlock doest not support changing base width'
|
|
|
|
|
|
|
|
|
|
self.conv1 = conv3x3(inplanes, planes, stride)
|
|
|
|
|
self.bn1 = nn.BatchNorm2d(planes)
|
|
|
|
|
self.relu = nn.ReLU(inplace=True)
|
|
|
|
|
self.conv2 = conv3x3(planes, planes)
|
|
|
|
|
self.bn2 = nn.BatchNorm2d(planes)
|
|
|
|
|
first_planes = planes // reduce_first
|
|
|
|
|
outplanes = planes * self.expansion
|
|
|
|
|
|
|
|
|
|
self.conv1 = nn.Conv2d(
|
|
|
|
|
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
|
|
|
|
|
dilation=dilation, bias=False)
|
|
|
|
|
self.bn1 = norm_layer(first_planes)
|
|
|
|
|
self.relu = nn.ReLU()
|
|
|
|
|
self.conv2 = nn.Conv2d(
|
|
|
|
|
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
|
|
|
|
|
dilation=previous_dilation, bias=False)
|
|
|
|
|
self.bn2 = norm_layer(outplanes)
|
|
|
|
|
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
|
|
|
|
self.downsample = downsample
|
|
|
|
|
self.stride = stride
|
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
|
self.dilation = dilation
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
residual = x
|
|
|
|
@ -87,13 +126,12 @@ class BasicBlock(nn.Module):
|
|
|
|
|
out = self.conv1(x)
|
|
|
|
|
out = self.bn1(out)
|
|
|
|
|
out = self.relu(out)
|
|
|
|
|
|
|
|
|
|
if self.drop_rate > 0.:
|
|
|
|
|
out = F.dropout(out, p=self.drop_rate, training=self.training)
|
|
|
|
|
|
|
|
|
|
out = self.conv2(out)
|
|
|
|
|
out = self.bn2(out)
|
|
|
|
|
|
|
|
|
|
if self.se is not None:
|
|
|
|
|
out = self.se(out)
|
|
|
|
|
|
|
|
|
|
if self.downsample is not None:
|
|
|
|
|
residual = self.downsample(x)
|
|
|
|
|
|
|
|
|
@ -107,22 +145,27 @@ class Bottleneck(nn.Module):
|
|
|
|
|
expansion = 4
|
|
|
|
|
|
|
|
|
|
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
|
|
|
|
cardinality=1, base_width=64, drop_rate=0.0):
|
|
|
|
|
cardinality=1, base_width=64, use_se=False,
|
|
|
|
|
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
|
|
|
|
super(Bottleneck, self).__init__()
|
|
|
|
|
|
|
|
|
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
|
|
|
|
|
|
|
|
|
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
|
|
|
|
|
self.bn1 = nn.BatchNorm2d(width)
|
|
|
|
|
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
|
|
|
|
|
padding=1, groups=cardinality, bias=False)
|
|
|
|
|
self.bn2 = nn.BatchNorm2d(width)
|
|
|
|
|
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
|
|
|
|
|
self.bn3 = nn.BatchNorm2d(planes * 4)
|
|
|
|
|
self.relu = nn.ReLU(inplace=True)
|
|
|
|
|
first_planes = width // reduce_first
|
|
|
|
|
outplanes = planes * self.expansion
|
|
|
|
|
|
|
|
|
|
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
|
|
|
|
self.bn1 = norm_layer(first_planes)
|
|
|
|
|
self.conv2 = nn.Conv2d(
|
|
|
|
|
first_planes, width, kernel_size=3, stride=stride,
|
|
|
|
|
padding=dilation, dilation=dilation, groups=cardinality, bias=False)
|
|
|
|
|
self.bn2 = norm_layer(width)
|
|
|
|
|
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
|
|
|
|
self.bn3 = norm_layer(outplanes)
|
|
|
|
|
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
|
|
|
|
self.relu = nn.ReLU()
|
|
|
|
|
self.downsample = downsample
|
|
|
|
|
self.stride = stride
|
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
|
self.dilation = dilation
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
residual = x
|
|
|
|
@ -131,9 +174,6 @@ class Bottleneck(nn.Module):
|
|
|
|
|
out = self.bn1(out)
|
|
|
|
|
out = self.relu(out)
|
|
|
|
|
|
|
|
|
|
if self.drop_rate > 0.:
|
|
|
|
|
out = F.dropout(out, p=self.drop_rate, training=self.training)
|
|
|
|
|
|
|
|
|
|
out = self.conv2(out)
|
|
|
|
|
out = self.bn2(out)
|
|
|
|
|
out = self.relu(out)
|
|
|
|
@ -141,6 +181,9 @@ class Bottleneck(nn.Module):
|
|
|
|
|
out = self.conv3(out)
|
|
|
|
|
out = self.bn3(out)
|
|
|
|
|
|
|
|
|
|
if self.se is not None:
|
|
|
|
|
out = self.se(out)
|
|
|
|
|
|
|
|
|
|
if self.downsample is not None:
|
|
|
|
|
residual = self.downsample(x)
|
|
|
|
|
|
|
|
|
@ -151,26 +194,110 @@ class Bottleneck(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResNet(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, block, layers, num_classes=1000, in_chans=3,
|
|
|
|
|
cardinality=1, base_width=64,
|
|
|
|
|
drop_rate=0.0, block_drop_rate=0.0,
|
|
|
|
|
global_pool='avg'):
|
|
|
|
|
"""ResNet / ResNeXt / SE-ResNeXt / SE-Net
|
|
|
|
|
|
|
|
|
|
This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that
|
|
|
|
|
* have > 1 stride in the 3x3 conv layer of bottleneck
|
|
|
|
|
* have conv-bn-act ordering
|
|
|
|
|
|
|
|
|
|
This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
|
|
|
|
|
variants included in the MXNet Gluon ResNetV1b model
|
|
|
|
|
|
|
|
|
|
ResNet variants:
|
|
|
|
|
* normal - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
|
|
|
|
|
* c - 3 layer deep 3x3 stem, stem_width = 32
|
|
|
|
|
* d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample
|
|
|
|
|
* e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample *no pretrained weights available
|
|
|
|
|
* s - 3 layer deep 3x3 stem, stem_width = 64
|
|
|
|
|
|
|
|
|
|
ResNeXt
|
|
|
|
|
* normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
|
|
|
|
|
* same c,d, e, s variants as ResNet can be enabled
|
|
|
|
|
|
|
|
|
|
SE-ResNeXt
|
|
|
|
|
* normal - 7x7 stem, stem_width = 64
|
|
|
|
|
* same c, d, e, s variants as ResNet can be enabled
|
|
|
|
|
|
|
|
|
|
SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
|
|
|
|
|
reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
|
|
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
|
----------
|
|
|
|
|
block : Block
|
|
|
|
|
Class for the residual block. Options are BasicBlockGl, BottleneckGl.
|
|
|
|
|
layers : list of int
|
|
|
|
|
Numbers of layers in each block
|
|
|
|
|
num_classes : int, default 1000
|
|
|
|
|
Number of classification classes.
|
|
|
|
|
in_chans : int, default 3
|
|
|
|
|
Number of input (color) channels.
|
|
|
|
|
use_se : bool, default False
|
|
|
|
|
Enable Squeeze-Excitation module in blocks
|
|
|
|
|
cardinality : int, default 1
|
|
|
|
|
Number of convolution groups for 3x3 conv in Bottleneck.
|
|
|
|
|
base_width : int, default 64
|
|
|
|
|
Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
|
|
|
|
|
deep_stem : bool, default False
|
|
|
|
|
Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
|
|
|
|
|
stem_width : int, default 64
|
|
|
|
|
Number of channels in stem convolutions
|
|
|
|
|
block_reduce_first: int, default 1
|
|
|
|
|
Reduction factor for first convolution output width of residual blocks,
|
|
|
|
|
1 for all archs except senets, where 2
|
|
|
|
|
down_kernel_size: int, default 1
|
|
|
|
|
Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
|
|
|
|
|
avg_down : bool, default False
|
|
|
|
|
Whether to use average pooling for projection skip connection between stages/downsample.
|
|
|
|
|
dilated : bool, default False
|
|
|
|
|
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
|
|
|
|
|
typically used in Semantic Segmentation.
|
|
|
|
|
drop_rate : float, default 0.
|
|
|
|
|
Dropout probability before classifier, for training
|
|
|
|
|
global_pool : str, default 'avg'
|
|
|
|
|
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
|
|
|
|
|
cardinality=1, base_width=64, stem_width=64, deep_stem=False,
|
|
|
|
|
block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
|
|
|
|
|
norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg'):
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.inplanes = 64
|
|
|
|
|
self.inplanes = stem_width * 2 if deep_stem else 64
|
|
|
|
|
self.cardinality = cardinality
|
|
|
|
|
self.base_width = base_width
|
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
|
self.expansion = block.expansion
|
|
|
|
|
self.dilated = dilated
|
|
|
|
|
super(ResNet, self).__init__()
|
|
|
|
|
self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
|
|
|
|
self.bn1 = nn.BatchNorm2d(64)
|
|
|
|
|
self.relu = nn.ReLU(inplace=True)
|
|
|
|
|
|
|
|
|
|
if deep_stem:
|
|
|
|
|
self.conv1 = nn.Sequential(*[
|
|
|
|
|
nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False),
|
|
|
|
|
norm_layer(stem_width),
|
|
|
|
|
nn.ReLU(),
|
|
|
|
|
nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False),
|
|
|
|
|
norm_layer(stem_width),
|
|
|
|
|
nn.ReLU(),
|
|
|
|
|
nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False)])
|
|
|
|
|
else:
|
|
|
|
|
self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
|
|
|
|
|
self.bn1 = norm_layer(self.inplanes)
|
|
|
|
|
self.relu = nn.ReLU()
|
|
|
|
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
|
|
|
|
self.layer1 = self._make_layer(block, 64, layers[0], drop_rate=block_drop_rate)
|
|
|
|
|
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, drop_rate=block_drop_rate)
|
|
|
|
|
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_rate=block_drop_rate)
|
|
|
|
|
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_rate=block_drop_rate)
|
|
|
|
|
stride_3_4 = 1 if self.dilated else 2
|
|
|
|
|
dilation_3 = 2 if self.dilated else 1
|
|
|
|
|
dilation_4 = 4 if self.dilated else 1
|
|
|
|
|
self.layer1 = self._make_layer(
|
|
|
|
|
block, 64, layers[0], stride=1, reduce_first=block_reduce_first,
|
|
|
|
|
use_se=use_se, avg_down=avg_down, down_kernel_size=1, norm_layer=norm_layer)
|
|
|
|
|
self.layer2 = self._make_layer(
|
|
|
|
|
block, 128, layers[1], stride=2, reduce_first=block_reduce_first,
|
|
|
|
|
use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
|
|
|
|
|
self.layer3 = self._make_layer(
|
|
|
|
|
block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, reduce_first=block_reduce_first,
|
|
|
|
|
use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
|
|
|
|
|
self.layer4 = self._make_layer(
|
|
|
|
|
block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, reduce_first=block_reduce_first,
|
|
|
|
|
use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
|
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
|
self.num_features = 512 * block.expansion
|
|
|
|
|
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
|
|
|
@ -182,18 +309,34 @@ class ResNet(nn.Module):
|
|
|
|
|
nn.init.constant_(m.weight, 1.)
|
|
|
|
|
nn.init.constant_(m.bias, 0.)
|
|
|
|
|
|
|
|
|
|
def _make_layer(self, block, planes, blocks, stride=1, drop_rate=0.):
|
|
|
|
|
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
|
|
|
|
|
use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d):
|
|
|
|
|
downsample = None
|
|
|
|
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
|
|
|
|
downsample = nn.Sequential(
|
|
|
|
|
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
|
|
|
|
nn.BatchNorm2d(planes * block.expansion),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width, drop_rate)]
|
|
|
|
|
downsample_padding = _get_padding(down_kernel_size, stride)
|
|
|
|
|
downsample_layers = []
|
|
|
|
|
conv_stride = stride
|
|
|
|
|
if avg_down:
|
|
|
|
|
avg_stride = stride if dilation == 1 else 1
|
|
|
|
|
conv_stride = 1
|
|
|
|
|
downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)]
|
|
|
|
|
downsample_layers += [
|
|
|
|
|
nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size,
|
|
|
|
|
stride=conv_stride, padding=downsample_padding, bias=False),
|
|
|
|
|
norm_layer(planes * block.expansion)]
|
|
|
|
|
downsample = nn.Sequential(*downsample_layers)
|
|
|
|
|
|
|
|
|
|
first_dilation = 1 if dilation in (1, 2) else 2
|
|
|
|
|
layers = [block(
|
|
|
|
|
self.inplanes, planes, stride, downsample,
|
|
|
|
|
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
|
|
|
|
|
use_se=use_se, dilation=first_dilation, previous_dilation=dilation, norm_layer=norm_layer)]
|
|
|
|
|
self.inplanes = planes * block.expansion
|
|
|
|
|
for i in range(1, blocks):
|
|
|
|
|
layers.append(block(self.inplanes, planes, cardinality=self.cardinality, base_width=self.base_width))
|
|
|
|
|
layers.append(block(
|
|
|
|
|
self.inplanes, planes,
|
|
|
|
|
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
|
|
|
|
|
use_se=use_se, dilation=dilation, previous_dilation=dilation, norm_layer=norm_layer))
|
|
|
|
|
|
|
|
|
|
return nn.Sequential(*layers)
|
|
|
|
|
|
|
|
|
@ -257,6 +400,33 @@ def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-26 model.
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfgs['resnet26']
|
|
|
|
|
model = ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-26 v1d model.
|
|
|
|
|
This is technically a 28 layer ResNet, sticking with 'd' modifier from Gluon for now.
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfgs['resnet26d']
|
|
|
|
|
model = ResNet(
|
|
|
|
|
Bottleneck, [2, 2, 2, 2], stem_width=32, deep_stem=True, avg_down=True,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-50 model.
|
|
|
|
@ -362,6 +532,21 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNeXt50-32x4d model.
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfgs['resnext50d_32x4d']
|
|
|
|
|
model = ResNet(
|
|
|
|
|
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
|
|
|
|
|
stem_width=32, deep_stem=True, avg_down=True,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNeXt-101 32x4d model.
|
|
|
|
|